[
  {
    "path": ".bazelrc",
    "content": "\n# Default options should come above this line\n\n# Put user-specific options in .bazelrc.user\ntry-import %workspace%/.bazelrc.user\n"
  },
  {
    "path": ".gitignore",
    "content": "# editor files\n*.swp\n*~\n.vscode/\n.DS_Store\n\n# bazel\n/.bazelrc.user\n/bazel-*\n\n# python\n*.pyc\n*.pyo\n__pycache__\n*.whl\n.ipynb_checkpoints\n"
  },
  {
    "path": "BUILD",
    "content": "# Description: Tensorflow Estimator.\n\nlicenses([\"notice\"])  # Apache 2.0\n\nexports_files([\"LICENSE\"])\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "Want to contribute? Great! First, read this page (including the small print at the end).\n\n### Before you contribute\n\nBefore we can use your code, you must sign the\n[Google Individual Contributor License Agreement](https://cla.developers.google.com/about/google-individual)\n(CLA), which you can do online. The CLA is necessary mainly because you own the\ncopyright to your changes, even after your contribution becomes part of our\ncodebase, so we need your permission to use and distribute your code. We also\nneed to be sure of various other things—for instance that you'll tell us if you\nknow that your code infringes on other people's patents. You don't have to sign\nthe CLA until after you've submitted your code for review and a member has\napproved it, but you must do it before we can put your code into our codebase.\nBefore you start working on a larger contribution, you should get in touch with\nus first through the issue tracker with your idea so that we can help out and\npossibly guide you. Coordinating up front makes it much easier to avoid\nfrustration later on.\n\n### Code reviews\n\nAll submissions, including submissions by project members, require review. We\nuse GitHub pull requests for this purpose.\n\n### The small print\n\nContributions made by corporations are covered by a different agreement than\nthe one above, the\n[Software Grant and Corporate Contributor License Agreement](https://cla.developers.google.com/about/google-corporate).\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright 2018 The TensorFlow Authors.  All rights reserved.\n\n                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative 
Works\" shall mean any work, whether in Source or Object\n      form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright 2017, The TensorFlow Authors.\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "-----------------\n| **`Documentation`** |\n|-----------------|\n| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/python/tf/estimator) |\n\nTensorFlow Estimator is a high-level TensorFlow API that greatly simplifies machine learning programming.\nEstimators encapsulate training, evaluation, prediction, and exporting for your model.\n\n## Getting Started\n\nSee our Estimator\n[getting started guide](https://www.tensorflow.org/guide/estimator) for an\nintroduction to the Estimator APIs.\n\n## Installation\n\n`tf.estimator` is installed when you install the TensorFlow pip package. See\n[Installing TensorFlow](https://www.tensorflow.org/install) for instructions.\n\n## Developing\n\nIf you want to build TensorFlow Estimator locally, you will need to\n[install Bazel](https://docs.bazel.build/versions/master/install.html) and\n[install TensorFlow](https://www.tensorflow.org/install/pip).\n\n```sh\n# To build TensorFlow Estimator whl file.\nbazel build //tensorflow_estimator/tools/pip_package:build_pip_package\nbazel-bin/tensorflow_estimator/tools/pip_package/build_pip_package /tmp/estimator_pip\n\n# To run all Estimator tests\nbazel test //tensorflow_estimator/...\n```\n\n## Contribution guidelines\n\nIf you want to contribute to TensorFlow Estimator, be sure to review the [contribution\nguidelines](CONTRIBUTING.md).\n\n**Note that this repository is included as a component of the main TensorFlow\npackage, and any issues encountered while using Estimators should be filed under\n[TensorFlow GitHub Issues](https://github.com/tensorflow/tensorflow/issues),\nas we do not separately track issues in this repository. 
You can link this\nrepository in any issues created as necessary.**\n\nPlease see\n[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions\nand discussion and please direct specific questions to\n[Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).\n\n## License\n\n[Apache License 2.0](LICENSE)\n"
  },
  {
    "path": "WORKSPACE",
    "content": "workspace(name = \"org_tensorflow_estimator\")\n\n# Use a custom python toolchain to make sure we always use the python binary\n# provided by PYTHON_BIN_PATH.\n# This is required due to https://github.com/bazelbuild/bazel/issues/7899,\n# because --python_path will not work since Bazel 0.27\nload(\"//third_party/py:python_configure.bzl\", \"python_configure\")\n\npython_configure(name = \"local_config_py_toolchain\")\n\nregister_toolchains(\"@local_config_py_toolchain//:py_toolchain\")\n"
  },
  {
    "path": "tensorflow_estimator/BUILD",
    "content": "# Placeholder: load py_library\n\n# Description: Tensorflow Estimator.\nload(\n    \"//tensorflow_estimator/python/estimator/api:api_gen.bzl\",\n    \"ESTIMATOR_API_INIT_FILES_V1\",\n    \"ESTIMATOR_API_INIT_FILES_V2\",\n    \"generate_apis\",\n)\n\nlicenses([\"notice\"])\n\npackage(default_visibility = [\"//tensorflow_estimator:internal\"])\n\nexports_files([\"LICENSE\"])\n\n# TODO(mikecase): Clean up. Remove all non estimator packages.\npackage_group(\n    name = \"internal\",\n    packages = [\n        \"//learning/brain/...\",\n        \"//learning/deepmind/research/...\",\n        \"//learning/tfx/models/uplift/estimators/...\",\n        \"//nlp/nlx/ads/expmatch/model/...\",\n        \"//nlp/nlx/common/query_bert/...\",\n        \"//nlp/nlx/i18n/pangloss/...\",\n        \"//tensorflow_estimator/...\",\n        \"//third_party/py/tensorflow_privacy/...\",\n        \"//third_party/tensorflow/python/estimator/...\",\n    ],\n)\n\n# This flag specifies whether Estimator 2.0 API should be built instead\n# of 1.* API. Note that Estimator 2.0 API is currently under development.\nconfig_setting(\n    name = \"api_version_2\",\n    define_values = {\"estimator_api_version\": \"2\"},\n)\n\nconfig_setting(\n    name = \"no_estimator_py_deps\",\n    define_values = {\"no_estimator_py_deps\": \"true\"},\n    visibility = [\"//visibility:public\"],\n)\n\npy_library(\n    name = \"tensorflow_estimator\",\n    srcs = [\n        \":root_init_gen\",\n        \":estimator_python_api_gen_compat_v1\",\n        \":estimator_python_api_gen_compat_v2\",\n        # Old API files. 
Delete once TensorFlow is updated to import from new location.\n        \"//tensorflow_estimator/python/estimator/api:estimator_python_api_gen\",\n        \"//tensorflow_estimator/python/estimator/api:estimator_python_api_gen_compat_v1\",\n        \"//tensorflow_estimator/python/estimator/api:estimator_python_api_gen_compat_v2\",\n    ],\n    srcs_version = \"PY3\",\n    visibility = [\n        \"//tensorflow_estimator:internal\",\n        \"//third_party/tensorflow/tools/docs/google:__subpackages__\",\n    ],\n    deps = [\n        \"//tensorflow_estimator/python/estimator:estimator_py\",\n    ],\n)\n\ngenrule(\n    name = \"root_init_gen\",\n    srcs = select({\n        \"api_version_2\": [\"_api/v2/v2.py\"],\n        \"//conditions:default\": [\"_api/v1/v1.py\"],\n    }),\n    outs = [\"__init__.py\"],\n    cmd = select({\n        \"api_version_2\": \"cp $(location :_api/v2/v2.py) $(OUTS)\",\n        \"//conditions:default\": \"cp $(location :_api/v1/v1.py) $(OUTS)\",\n    }),\n)\n\ngenerate_apis(\n    name = \"estimator_python_api_gen_compat_v1\",\n    api_version = 1,\n    output_dir = \"_api/v1/\",\n    output_files = ESTIMATOR_API_INIT_FILES_V1,\n    output_package = \"tensorflow_estimator._api.v1\",\n    root_file_name = \"v1.py\",\n)\n\ngenerate_apis(\n    name = \"estimator_python_api_gen_compat_v2\",\n    api_version = 2,\n    output_dir = \"_api/v2/\",\n    output_files = ESTIMATOR_API_INIT_FILES_V2,\n    output_package = \"tensorflow_estimator._api.v2\",\n    root_file_name = \"v2.py\",\n)\n"
  },
  {
    "path": "tensorflow_estimator/estimator.bzl",
    "content": "\"\"\"Estimator common skylark macros.\"\"\"\n\n# Macro to run Estimator py_tests against pip installation.\ndef py_test(deps = [], **kwargs):\n    native.py_test(\n        deps = select({\n            \"//conditions:default\": deps,\n            \"//tensorflow_estimator:no_estimator_py_deps\": [],\n        }),\n        **kwargs\n    )\n\ndef tpu_py_test(**kwargs):\n    # Skip the tpu test for Estimator oss.\n    pass\n\n# We are never indexing generated code in the OSS build, but still\n# return a select() for consistency.\ndef if_indexing_source_code(\n        if_true,  # @unused\n        if_false):\n    \"\"\"Return a select() on whether or not we are building for source code indexing.\"\"\"\n    return select({\n        \"//conditions:default\": if_false,\n    })\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/BUILD",
    "content": "# Placeholder: load py_library\nload(\"//tensorflow_estimator:estimator.bzl\", \"py_test\")\n\npackage(default_visibility = [\"//tensorflow_estimator:internal\"])\n\nlicenses([\"notice\"])\n\npy_test(\n    name = \"tf_estimator_doctest\",\n    srcs = [\"tf_estimator_doctest.py\"],\n    python_version = \"PY3\",\n    tags = [\n        \"no_oss_py2\",\n        \"noasan\",\n        \"nomsan\",\n        \"notsan\",\n    ],\n    deps = [\n        \":estimator_py\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n    ],\n)\n\npy_library(\n    name = \"estimator_py\",\n    srcs = [\n        \"estimator_lib.py\",\n    ],\n    srcs_version = \"PY3\",\n    visibility = [\"//visibility:public\"],\n    deps = [\n        \":base_head\",\n        \":baseline\",\n        \":basic_session_run_hooks\",\n        \":binary_class_head\",\n        \":checkpoint_converter\",\n        \":dnn\",\n        \":dnn_linear_combined\",\n        \":early_stopping\",\n        \":estimator\",\n        \":export\",\n        \":exporter\",\n        \":extenders\",\n        \":fake_summary_writer\",\n        \":function\",\n        \":hooks\",\n        \":inputs\",\n        \":keras\",\n        \":kmeans\",\n        \":linear\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":multi_class_head\",\n        \":multi_head\",\n        \":multi_label_head\",\n        \":parsing_utils\",\n        \":regression_head\",\n        \":rnn\",\n        \":run_config\",\n        \":saved_model_estimator\",\n        \":sequential_head\",\n        \":session_run_hook\",\n        \":training\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorboard_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n        \"//tensorflow_estimator/python/estimator/canned/timeseries:estimators\",\n        
\"//tensorflow_estimator/python/estimator/tpu:tpu_estimator\",\n    ],\n)\n\npy_library(\n    name = \"exporter\",\n    srcs = [\"exporter.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_export\",\n        \":gc\",\n        \":metric_keys\",\n        \":util\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"exporter_test\",\n    size = \"medium\",\n    srcs = [\"exporter_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator\",\n        \":exporter\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"extenders\",\n    srcs = [\"extenders.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator\",\n        \":estimator_export\",\n        \":mode_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"extenders_test\",\n    size = \"medium\",\n    srcs = [\"extenders_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    tags = [\"notsan\"],  # b/62863147\n    deps = [\n        \":extenders\",\n        \":linear\",\n        \":run_config\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = 
\"gc\",\n    srcs = [\"gc.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"gc_test\",\n    size = \"small\",\n    srcs = [\"gc_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":gc\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"hooks\",\n    srcs = [\"hooks/hooks.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_export\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"hooks_test\",\n    srcs = [\"hooks/hooks_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_py\",\n        \":hooks\",\n        \"//tensorflow_estimator/python/estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"model_fn\",\n    srcs = [\"model_fn.py\"],\n    srcs_version = \"PY3\",\n    visibility = [\n        \"//tensorflow_estimator:internal\",\n        \"//third_party/tensorflow/python/tpu:__pkg__\",\n    ],\n    deps = [\n        \":estimator_export\",\n        \":mode_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        
\"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"model_fn_test\",\n    size = \"small\",\n    srcs = [\"model_fn_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":export_output\",\n        \":model_fn\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"mode_keys\",\n    srcs = [\"mode_keys.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_export\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"training\",\n    srcs = [\"training.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator\",\n        \":estimator_export\",\n        \":exporter\",\n        \":run_config\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"training_test\",\n    size = \"medium\",\n    srcs = [\"training_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    tags = [\n        \"notap\",  # TODO(b/170896944): flaky, broken\n        \"notsan\",\n    ],\n    deps = [\n        \":dnn\",\n        \":estimator\",\n        \":exporter\",\n        \":inputs\",\n        \":run_config\",\n        \":training\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        
\"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"run_config\",\n    srcs = [\"run_config.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_export\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"run_config_test\",\n    size = \"small\",\n    srcs = [\"run_config_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":run_config\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"baseline\",\n    srcs = [\"canned/baseline.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator\",\n        \":estimator_export\",\n        \":head\",\n        \":head_utils\",\n        \":model_fn\",\n        \":optimizers\",\n        \":regression_head\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"baseline_test\",\n    size = \"medium\",\n    srcs = [\"canned/baseline_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"noasan\",  # test flakily times out in asan mode.\n        \"notsan\",  # b/67510291\n        \"optonly\",  # flakily times out in fastbuild\n    ],\n   
 deps = [\n        \":baseline\",\n        \":estimator\",\n        \":export_export\",\n        \":metric_keys\",\n        \":numpy_io\",\n        \":pandas_io\",\n        \":run_config\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_pandas_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"baseline_test_v1\",\n    size = \"medium\",\n    srcs = [\"canned/v1/baseline_test_v1.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"noasan\",  # test flakily times out in asan mode.\n        \"notsan\",  # b/67510291\n        \"optonly\",  # flakily times out in fastbuild\n    ],\n    deps = [\n        \":baseline\",\n        \":estimator\",\n        \":export_export\",\n        \":metric_keys\",\n        \":numpy_io\",\n        \":pandas_io\",\n        \":run_config\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_pandas_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"baseline_estimator_test\",\n    size = \"medium\",\n    srcs = [\"canned/baseline_estimator_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"noasan\",  # test flakily times out in asan mode.\n        
\"notsan\",  # b/67510291\n        \"optonly\",  # flakily times out in fastbuild\n    ],\n    deps = [\n        \":baseline\",\n        \":estimator\",\n        \":export_export\",\n        \":metric_keys\",\n        \":numpy_io\",\n        \":regression_head\",\n        \":run_config\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"baseline_estimator_test_v1\",\n    size = \"medium\",\n    srcs = [\"canned/v1/baseline_estimator_test_v1.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"noasan\",  # test flakily times out in asan mode.\n        \"notsan\",  # b/67510291\n        \"optonly\",  # flakily times out in fastbuild\n    ],\n    deps = [\n        \":baseline\",\n        \":estimator\",\n        \":export_export\",\n        \":metric_keys\",\n        \":numpy_io\",\n        \":run_config\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"kmeans\",\n    srcs = [\"canned/kmeans.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator\",\n        \":estimator_export\",\n        \":head\",\n        \":model_fn\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_test(\n    name = \"kmeans_test\",\n    size = 
\"medium\",\n    srcs = [\"canned/kmeans_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 8,\n    srcs_version = \"PY3\",\n    tags = [\n        \"notap\",  # TODO(b/170974352): Flaky timeout\n    ],\n    deps = [\n        \":inputs\",\n        \":kmeans\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_library(\n    name = \"dnn\",\n    srcs = [\"canned/dnn.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator\",\n        \":estimator_export\",\n        \":head\",\n        \":head_utils\",\n        \":mode_keys\",\n        \":optimizers\",\n        \":regression_head\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"dnn_testing_utils\",\n    srcs = [\"canned/dnn_testing_utils.py\"],\n    srcs_version = \"PY3\",\n    visibility = [\"//visibility:public\"],\n    deps = [\n        \":estimator\",\n        \":head\",\n        \":metric_keys\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":numpy_io\",\n        \":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"dnn_testing_utils_v1\",\n    srcs = [\"canned/v1/dnn_testing_utils_v1.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n     
   \":estimator\",\n        \":head\",\n        \":metric_keys\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":numpy_io\",\n        \":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"dnn_test_fc_v1_v1\",\n    size = \"medium\",\n    srcs = [\"canned/v1/dnn_test_fc_v1_v1.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"notsan\",  # b/67510291\n    ],\n    deps = [\n        \":dnn\",\n        \":dnn_testing_utils_v1\",\n        \":export_export\",\n        \":numpy_io\",\n        \":pandas_io\",\n        \":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_pandas_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"dnn_test_fc_v2\",\n    size = \"medium\",\n    srcs = [\"canned/dnn_test_fc_v2.py\"],\n    python_version = \"PY3\",\n    shard_count = 8,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"notsan\",  # b/67510291\n    ],\n    deps = [\n        \":dnn\",\n        \":dnn_testing_utils\",\n        \":export_export\",\n        \":numpy_io\",\n        \":pandas_io\",\n        \":prediction_keys\",\n        
\"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_pandas_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"dnn_test_fc_v2_v1\",\n    size = \"medium\",\n    srcs = [\"canned/v1/dnn_test_fc_v2_v1.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"notsan\",  # b/67510291\n    ],\n    deps = [\n        \":dnn\",\n        \":dnn_testing_utils_v1\",\n        \":export_export\",\n        \":numpy_io\",\n        \":pandas_io\",\n        \":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_pandas_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"dnn_estimator_test\",\n    size = \"medium\",\n    srcs = [\"canned/dnn_estimator_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"notsan\",\n        \"optonly\",  # times out http://b/79220679\n    ],\n    deps = [\n        \":dnn\",\n        \":dnn_testing_utils\",\n        \":export_export\",\n        \":multi_class_head\",\n        \":numpy_io\",\n        \":prediction_keys\",\n        \":regression_head\",\n        
\"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"dnn_estimator_test_v1\",\n    size = \"medium\",\n    srcs = [\"canned/v1/dnn_estimator_test_v1.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"notsan\",\n        \"optonly\",  # times out http://b/79220679\n    ],\n    deps = [\n        \":dnn\",\n        \":dnn_testing_utils_v1\",\n        \":export_export\",\n        \":head\",\n        \":numpy_io\",\n        \":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"dnn_linear_combined\",\n    srcs = [\"canned/dnn_linear_combined.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":dnn\",\n        \":estimator\",\n        \":estimator_export\",\n        \":head\",\n        \":head_utils\",\n        \":linear\",\n        \":model_fn\",\n        \":optimizers\",\n        \":regression_head\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"dnn_linear_combined_estimator_test\",\n    size = \"medium\",\n    srcs = [\"canned/dnn_linear_combined_estimator_test.py\"],\n    
python_version = \"PY3\",\n    shard_count = 3,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"notsan\",\n    ],\n    deps = [\n        \":dnn_linear_combined\",\n        \":dnn_testing_utils\",\n        \":export_export\",\n        \":linear_testing_utils\",\n        \":numpy_io\",\n        \":prediction_keys\",\n        \":regression_head\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"dnn_linear_combined_estimator_test_v1\",\n    size = \"medium\",\n    srcs = [\"canned/v1/dnn_linear_combined_estimator_test_v1.py\"],\n    python_version = \"PY3\",\n    shard_count = 3,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"notsan\",\n    ],\n    deps = [\n        \":dnn_linear_combined\",\n        \":dnn_testing_utils_v1\",\n        \":export_export\",\n        \":head\",\n        \":linear_testing_utils_v1\",\n        \":numpy_io\",\n        \":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"dnn_linear_combined_test\",\n    size = \"medium\",\n    srcs = [\"canned/dnn_linear_combined_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 32,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_oss\",  # TODO(b/143323557)\n        \"no_pip\",\n        \"notsan\",  # TODO(b/67510291)\n    ],\n  
  deps = [\n        \":dnn_linear_combined\",\n        \":dnn_testing_utils\",\n        \":export_export\",\n        \":linear_testing_utils\",\n        \":numpy_io\",\n        \":pandas_io\",\n        \":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_pandas_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"dnn_linear_combined_test_v1\",\n    size = \"medium\",\n    srcs = [\"canned/v1/dnn_linear_combined_test_v1.py\"],\n    python_version = \"PY3\",\n    shard_count = 16,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"notsan\",  # TODO(b/67510291)\n    ],\n    deps = [\n        \":dnn_linear_combined\",\n        \":dnn_testing_utils_v1\",\n        \":export_export\",\n        \":linear_testing_utils_v1\",\n        \":numpy_io\",\n        \":pandas_io\",\n        \":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_pandas_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"checkpoint_converter\",\n    srcs = [\"tools/checkpoint_converter.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        
\"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"analytics_tools\",\n    srcs = [\"tools/analytics.py\"],\n    srcs_version = \"PY3\",\n    deps = [\"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\"],\n)\n\npy_test(\n    name = \"checkpoint_converter_test\",\n    srcs = [\"tools/checkpoint_converter_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    deps = [\n        \":checkpoint_converter\",\n        \":dnn\",\n        \":dnn_linear_combined\",\n        \":head\",\n        \":linear\",\n        \":numpy_io\",\n        \":regression_head\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_library(\n    name = \"util\",\n    srcs = [\n        \"util.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"util_test\",\n    srcs = [\"util_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    tags = [\"notsan\"],  # b/67510291\n    deps = [\n        \":util\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n   
     \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"early_stopping\",\n    srcs = [\n        \"early_stopping.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_export\",\n        \":export_export\",\n        \":model_fn\",\n        \":run_config\",\n        \":util\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"early_stopping_test\",\n    srcs = [\n        \"early_stopping_test.py\",\n    ],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    tags = [\n        \"notap\",  # TODO(b/134928532): Reenable this test.\n    ],\n    deps = [\n        \":early_stopping\",\n        \"//tensorflow_estimator/python/estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n    ],\n)\n\npy_library(\n    name = \"estimator\",\n    srcs = [\n        \"estimator.py\",\n    ],\n    srcs_version = \"PY3\",\n    visibility = [\n        \"//tensorflow_estimator:internal\",\n        \"//third_party/tensorflow/python/tpu:__pkg__\",\n    ],\n    deps = [\n        \":estimator_export\",\n        \":export\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":run_config\",\n        \":util\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = 
\"estimator_test\",\n    srcs = [\"estimator_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    tags = [\"notsan\"],  # b/67510291\n    deps = [\n        \":estimator\",\n        \":estimator_py\",\n        \":export\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":numpy_io\",\n        \":run_config\",\n        # Placeholder for an internal build dep disabling tf2 behavior\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"parsing_utils\",\n    srcs = [\n        \"canned/parsing_utils.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_export\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"parsing_utils_test\",\n    srcs = [\"canned/parsing_utils_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":parsing_utils\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"export_output\",\n    srcs = [\"export/export_output.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_export\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n  
  ],\n)\n\npy_library(\n    name = \"export\",\n    srcs = [\n        \"export/export_lib.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \":export_export\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_library(\n    name = \"export_export\",\n    srcs = [\n        \"export/export.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_export\",\n        \":util\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_test(\n    name = \"export_test\",\n    size = \"small\",\n    srcs = [\"export/export_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":export_export\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"function\",\n    srcs = [\n        \"export/function.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \":mode_keys\",\n        \":model_fn\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_test(\n    name = \"function_test\",\n    size = \"small\",\n    srcs = [\"export/function_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":export\",\n        \":function\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_library(\n    name = \"head\",\n    srcs = [\"canned/head.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":export_output\",\n        \":metric_keys\",\n        \":model_fn\",\n        
\":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"head_test\",\n    size = \"medium\",\n    srcs = [\"canned/head_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    tags = [\n        \"manual\",\n        \"no_pip\",\n        \"notap\",  # b/148804861\n    ],\n    deps = [\n        \":dnn_testing_utils_v1\",\n        \":head\",\n        \":metric_keys\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":numpy_io\",\n        \":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"head_utils\",\n    srcs = [\"head/head_utils.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":binary_class_head\",\n        \":multi_class_head\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"base_head\",\n    srcs = [\"head/base_head.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_export\",\n        \":export_output\",\n        \":head\",\n        \":metric_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = 
\"base_head_test\",\n    size = \"small\",\n    srcs = [\"head/base_head_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":base_head_test_lib\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n    ],\n)\n\npy_library(\n    name = \"base_head_test_lib\",\n    testonly = True,\n    srcs = [\"head/base_head_test.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":base_head\",\n        \":binary_class_head\",\n        \":head_utils\",\n        \":mode_keys\",\n        \":model_fn\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"binary_class_head\",\n    srcs = [\"head/binary_class_head.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":base_head\",\n        \":estimator_export\",\n        \":export_output\",\n        \":metric_keys\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"binary_class_head_test\",\n    size = \"medium\",\n    srcs = [\"head/binary_class_head_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    tags = [\n        \"manual\",\n        \"no_pip\",\n        \"notap\",  # b/148804861\n    ],\n    deps = [\n        \":binary_class_head\",\n        \":dnn\",\n        \":dnn_testing_utils\",\n        \":head_utils\",\n        \":metric_keys\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        
\"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"multi_head\",\n    srcs = [\"head/multi_head.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":base_head\",\n        \":estimator_export\",\n        \":export_output\",\n        \":metric_keys\",\n        \":mode_keys\",\n        \":model_fn\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"multi_head_test\",\n    size = \"medium\",\n    srcs = [\"head/multi_head_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    deps = [\n        \":head_utils\",\n        \":metric_keys\",\n        \":mode_keys\",\n        \":multi_head\",\n        \":multi_label_head\",\n        \":prediction_keys\",\n        \":regression_head\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"multi_class_head\",\n    srcs = [\"head/multi_class_head.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":base_head\",\n        \":estimator_export\",\n        \":export_output\",\n        \":metric_keys\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":prediction_keys\",\n        
\"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"multi_class_head_test\",\n    size = \"medium\",\n    srcs = [\"head/multi_class_head_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_oss\",  # TODO(b/202525254): broken on TF 2.7\n    ],\n    deps = [\n        \":dnn\",\n        \":head_utils\",\n        \":metric_keys\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":multi_class_head\",\n        \":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"multi_label_head\",\n    srcs = [\"head/multi_label_head.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":base_head\",\n        \":estimator_export\",\n        \":export_output\",\n        \":metric_keys\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"multi_label_head_test\",\n    size = \"medium\",\n    srcs = [\"head/multi_label_head_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    deps = [\n        \":dnn\",\n        \":head_utils\",\n        \":metric_keys\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":multi_label_head\",\n        \":prediction_keys\",\n        
\"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"regression_head\",\n    srcs = [\"head/regression_head.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":base_head\",\n        \":estimator_export\",\n        \":export_output\",\n        \":metric_keys\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"regression_head_test\",\n    size = \"medium\",\n    srcs = [\"head/regression_head_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    tags = [\n        \"manual\",\n        \"notap\",  # b/148804861\n    ],\n    deps = [\n        \":head_utils\",\n        \":metric_keys\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":numpy_io\",\n        \":prediction_keys\",\n        \":regression_head\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"sequential_head\",\n    srcs = [\"head/sequential_head.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":base_head\",\n        \":mode_keys\",\n        \":multi_head\",\n        
\"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_test(\n    name = \"sequential_head_test\",\n    size = \"medium\",\n    srcs = [\"head/sequential_head_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":binary_class_head\",\n        \":head_utils\",\n        \":metric_keys\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":multi_class_head\",\n        \":multi_head\",\n        \":prediction_keys\",\n        \":sequential_head\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_library(\n    name = \"inputs\",\n    srcs = [\"inputs/inputs.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":numpy_io\",\n        \":pandas_io\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"linear\",\n    srcs = [\"canned/linear.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":binary_class_head\",\n        \":estimator\",\n        \":estimator_export\",\n        \":head\",\n        \":head_utils\",\n        \":optimizers\",\n        \":regression_head\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n        \"//tensorflow_estimator/python/estimator/canned/linear_optimizer:sdca_ops_py\",\n    ],\n)\n\npy_library(\n    name = \"linear_testing_utils\",\n    
srcs = [\"canned/linear_testing_utils.py\"],\n    srcs_version = \"PY3\",\n    visibility = [\"//visibility:public\"],\n    deps = [\n        \":estimator\",\n        \":export_export\",\n        \":linear\",\n        \":metric_keys\",\n        \":numpy_io\",\n        \":pandas_io\",\n        \":run_config\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_pandas_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"linear_testing_utils_v1\",\n    srcs = [\"canned/v1/linear_testing_utils_v1.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator\",\n        \":export_export\",\n        \":linear\",\n        \":metric_keys\",\n        \":numpy_io\",\n        \":pandas_io\",\n        \":run_config\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_pandas_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"linear_estimator_test\",\n    size = \"medium\",\n    srcs = [\"canned/linear_estimator_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"notsan\",\n    ],\n    deps = [\n        \":export_export\",\n        \":linear\",\n        \":linear_testing_utils\",\n        \":numpy_io\",\n        \":prediction_keys\",\n        \":regression_head\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        
\"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"linear_estimator_test_v1\",\n    size = \"medium\",\n    srcs = [\"canned/v1/linear_estimator_test_v1.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"notsan\",\n    ],\n    deps = [\n        \":export_export\",\n        \":head\",\n        \":linear\",\n        \":linear_testing_utils_v1\",\n        \":numpy_io\",\n        \":prediction_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"linear_test\",\n    size = \"medium\",\n    srcs = [\"canned/linear_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 8,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"notsan\",  # b/67510291\n    ],\n    deps = [\n        \":linear\",\n        \":linear_testing_utils\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\nfilegroup(\n    name = \"vocabulary_testdata\",\n    srcs = [\n        \"canned/testdata/wire_vocabulary.txt\",\n    ],\n)\n\npy_test(\n    name = \"linear_model_test\",\n    size = \"medium\",\n    srcs = [\"canned/linear_model_test.py\"],\n    data = [\":vocabulary_testdata\"],\n    
python_version = \"PY3\",\n    shard_count = 8,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_cuda_on_cpu_tap\",\n        \"no_pip\",\n        \"no_rocm\",\n        \"no_windows\",\n        \"notsan\",  # b/67510291\n    ],\n    deps = [\n        \":linear\",\n        \":linear_testing_utils\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"linear_test_v1\",\n    size = \"medium\",\n    srcs = [\"canned/v1/linear_test_v1.py\"],\n    python_version = \"PY3\",\n    shard_count = 8,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_pip\",\n        \"notsan\",  # b/67510291\n    ],\n    deps = [\n        \":linear\",\n        \":linear_testing_utils_v1\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"metric_keys\",\n    srcs = [\"canned/metric_keys.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":model_fn\",\n    ],\n)\n\npy_library(\n    name = \"numpy_io\",\n    srcs = [\"inputs/numpy_io.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_export\",\n        \":inputs_queues\",\n    ],\n)\n\npy_test(\n    name = \"numpy_io_test\",\n    size = \"small\",\n    srcs = [\"inputs/numpy_io_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":numpy_io\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    
],\n)\n\npy_library(\n    name = \"optimizers\",\n    srcs = [\"canned/optimizers.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":util\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"optimizers_test\",\n    size = \"small\",\n    srcs = [\"canned/optimizers_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":optimizers\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"optimizers_test_v2\",\n    size = \"small\",\n    srcs = [\"canned/optimizers_test_v2.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":optimizers\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_test(\n    name = \"object_checkpointing_test\",\n    size = \"medium\",\n    srcs = [\"object_checkpointing_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator\",\n        \":model_fn\",\n        \":optimizers\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"pandas_io\",\n    srcs = [\"inputs/pandas_io.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_export\",\n        \":inputs_queues\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n    
    \"//tensorflow_estimator/python/estimator:expect_pandas_installed\",\n    ],\n)\n\npy_test(\n    name = \"pandas_io_test\",\n    size = \"small\",\n    srcs = [\"inputs/pandas_io_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":pandas_io\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_pandas_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"prediction_keys\",\n    srcs = [\"canned/prediction_keys.py\"],\n    srcs_version = \"PY3\",\n    visibility = [\"//visibility:public\"],\n    deps = [],\n)\n\npy_library(\n    name = \"inputs_queues\",\n    srcs = [\n        \"inputs/queues/__init__.py\",\n        \"inputs/queues/feeding_functions.py\",\n        \"inputs/queues/feeding_queue_runner.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_pandas_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"feeding_functions_test\",\n    size = \"small\",\n    srcs = [\n        \"inputs/queues/feeding_functions_test.py\",\n    ],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":inputs_queues\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_pandas_installed\",\n        
\"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"feeding_queue_runner_test\",\n    size = \"small\",\n    srcs = [\"inputs/queues/feeding_queue_runner_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":inputs_queues\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_pandas_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"keras\",\n    srcs = [\"keras_lib.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator\",\n        \":export\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":run_config\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"keras_test\",\n    size = \"medium\",\n    srcs = [\"keras_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 8,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_windows\",\n        \"notsan\",  # b/67510291\n    ],\n    deps = [\n        \":export\",\n        \":keras\",\n        \":mode_keys\",\n        \":numpy_io\",\n        \":run_config\",\n        \"//tensorflow_estimator/python/estimator:expect_h5py_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        
\"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"keras_premade_model_test\",\n    size = \"medium\",\n    srcs = [\"keras_premade_model_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    deps = [\n        \":export\",\n        \":keras\",\n        \":mode_keys\",\n        \":numpy_io\",\n        \":run_config\",\n        \"//tensorflow_estimator/python/estimator:expect_h5py_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"keras_distribute_strategy_test\",\n    srcs = [\"keras_distribute_strategy_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    tags = [\"notsan\"],\n    deps = [\n        \":keras\",\n        \":run_config\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_library(\n    name = \"saved_model_estimator\",\n    srcs = [\"canned/saved_model_estimator.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator\",\n        \":estimator_export\",\n        \":export\",\n        \":mode_keys\",\n        \":model_fn\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"saved_model_estimator_test\",\n   
 size = \"medium\",\n    srcs = [\"canned/saved_model_estimator_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    tags = [\n        \"notsan\",\n    ],\n    deps = [\n        \":estimator\",\n        \":export\",\n        \":mode_keys\",\n        \":model_fn\",\n        \":saved_model_estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"basic_session_run_hooks\",\n    srcs = [\"hooks/basic_session_run_hooks.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_export\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"basic_session_run_hooks_test\",\n    size = \"medium\",\n    srcs = [\"hooks/basic_session_run_hooks_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_py\",\n        \":fake_summary_writer\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n    ],\n)\n\npy_library(\n    name = \"session_run_hook\",\n    srcs = [\"hooks/session_run_hook.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_export\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"fake_summary_writer\",\n    srcs = [\n        \"hooks/fake_summary_writer.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        
\"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"rnn\",\n    srcs = [\"canned/rnn.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":binary_class_head\",\n        \":estimator\",\n        \":estimator_export\",\n        \":multi_class_head\",\n        \":optimizers\",\n        \":sequential_head\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"rnn_test\",\n    size = \"medium\",\n    srcs = [\"canned/rnn_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_oss\",  # b/140934549\n        \"no_pip\",\n        \"noasan\",  # times out\n        \"notsan\",\n        \"optonly\",  # times out http://b/79220679\n    ],\n    deps = [\n        \":export\",\n        \":head\",\n        \":metric_keys\",\n        \":multi_class_head\",\n        \":numpy_io\",\n        \":parsing_utils\",\n        \":prediction_keys\",\n        \":rnn\",\n        \":sequential_head\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"estimator_export\",\n    srcs = [\"estimator_export.py\"],\n    srcs_version = \"PY3\",\n    visibility = [\"//tensorflow_estimator:internal\"],\n    deps = [\n        \":expect_tensorflow_installed\",\n        \":util\",\n    ],\n)\n\npy_test(\n  
  name = \"estimator_export_test\",\n    srcs = [\"estimator_export_test.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimator_export\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_library(\n    name = \"expect_absl_installed\",\n    srcs_version = \"PY3\",\n    # This is a dummy rule used as an absl dependency in open-source.\n    # We expect absl to already be installed on the system, e.g. via\n    # `pip install absl`\n    visibility = [\"//visibility:public\"],\n)\n\npy_library(\n    name = \"expect_numpy_installed\",\n    srcs_version = \"PY3\",\n    # This is a dummy rule used as a numpy dependency in open-source.\n    # We expect numpy to already be installed on the system, e.g. via\n    # `pip install numpy`\n    visibility = [\"//visibility:public\"],\n)\n\npy_library(\n    name = \"expect_pandas_installed\",\n    # This is a dummy rule used as a pandas dependency in open-source.\n    # We expect pandas to already be installed on the system, e.g. via\n    # `pip install pandas`\n    visibility = [\"//visibility:public\"],\n)\n\npy_library(\n    name = \"expect_h5py_installed\",\n    srcs_version = \"PY3\",\n    # This is a dummy rule used as an h5py dependency in open-source.\n    # We expect h5py to already be installed on the system, e.g. via\n    # `pip install h5py`\n    visibility = [\"//visibility:public\"],\n)\n\npy_library(\n    name = \"expect_six_installed\",\n    srcs_version = \"PY3\",\n    # This is a dummy rule used as a six dependency in open-source.\n    # We expect six to already be installed on the system, e.g. 
via\n    # `pip install six`\n    visibility = [\"//visibility:public\"],\n)\n\npy_library(\n    name = \"expect_tensorboard_installed\",\n    srcs_version = \"PY3\",\n    # This is a dummy rule used as a tensorboard dependency in open-source.\n    # We expect tensorboard to already be installed on the system, e.g. via\n    # `pip install tensorboard`.\n    visibility = [\"//visibility:public\"],\n)\n\npy_library(\n    name = \"expect_tensorflow_installed\",\n    srcs_version = \"PY3\",\n    # This is a dummy rule used as a tensorflow dependency in open-source.\n    # We expect tensorflow to already be installed on the system, e.g. via\n    # `pip install tensorflow` or `pip install tensorflow_gpu`\n    visibility = [\"//visibility:public\"],\n)\n\npy_library(\n    name = \"expect_tensorflow_keras_installed\",\n    srcs_version = \"PY3\",\n    # This is a dummy rule used as a tensorflow_keras dependency in open-source.\n    # We expect tensorflow to already be installed on the system, e.g. via\n    # `pip install tensorflow` or `pip install tensorflow_gpu`\n    visibility = [\"//visibility:public\"],\n)\n\npy_library(\n    name = \"expect_proto_cpp_installed\",\n    srcs_version = \"PY3\",\n    # This is a dummy rule used as a protobuf dependency in open-source.\n    # We expect protobuf cpp python to already be installed on the system.\n    visibility = [\"//visibility:public\"],\n)\n\n# The following targets are emulating cuda_py_test from //third_party/tensorflow:tensorflow.google.bzl\n# cuda_py_test cannot be used directly because the bzl file cannot be imported into tensorflow_estimator\n\npy_test(\n    name = \"distribute_strategy_estimator_integration_test\",\n    size = \"medium\",\n    srcs = [\"distribute_strategy_estimator_integration_test.py\"],\n    main = \"distribute_strategy_estimator_integration_test.py\",\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    tags = [\n        \"tf_integration_test\",\n    ],\n    deps = [\n        \":estimator_py\",\n     
   \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_test(\n    name = \"distribute_strategy_estimator_integration_test_gpu\",\n    size = \"medium\",\n    srcs = [\"distribute_strategy_estimator_integration_test.py\"],\n    main = \"distribute_strategy_estimator_integration_test.py\",\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    tags = [\n        \"cuda\",\n        \"gpu\",\n        \"multi_and_single_gpu\",\n        \"requires-gpu-nvidia\",\n        \"tf_integration_test\",\n    ],\n    deps = [\n        \":estimator_py\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_test(\n    name = \"distribute_strategy_estimator_training_test\",\n    size = \"medium\",\n    srcs = [\"distribute_strategy_estimator_training_test.py\"],\n    main = \"distribute_strategy_estimator_training_test.py\",\n    python_version = \"PY3\",\n    shard_count = 48,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_oss\",  # b/140933379\n        # TODO(b/118768923): Re-enable {a,m,t}san test.\n        \"noasan\",\n        \"nomsan\",\n        \"notsan\",\n    ],\n    deps = [\n        \":estimator_py\",\n        # Placeholder for an internal build dep disabling tf2 behavior\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    
],\n)\n\npy_test(\n    name = \"distribute_strategy_estimator_training_test_gpu\",\n    size = \"medium\",\n    srcs = [\"distribute_strategy_estimator_training_test.py\"],\n    main = \"distribute_strategy_estimator_training_test.py\",\n    python_version = \"PY3\",\n    shard_count = 48,\n    srcs_version = \"PY3\",\n    tags = [\n        # TODO(b/118768923): Re-enable {a,m,t}san test.\n        \"noasan\",\n        \"nomsan\",\n        \"notsan\",\n        \"cuda\",\n        \"requires-gpu-nvidia\",\n        \"gpu\",\n        \"multi_and_single_gpu\",\n    ],\n    deps = [\n        \":estimator_py\",\n        # Placeholder for an internal build dep disabling tf2 behavior\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_test(\n    name = \"canned_estimator_ds_integration_test\",\n    size = \"medium\",\n    srcs = [\"canned/canned_estimator_ds_integration_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    tags = [\n        \"notap\",  # TODO(b/161835009): Re-enable.\n    ],\n    deps = [\n        \":estimator_py\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_test(\n    name = \"canned_estimator_ds_integration_test_gpu\",\n    size = \"medium\",\n    srcs = [\"canned/canned_estimator_ds_integration_test.py\"],\n    main = \"canned/canned_estimator_ds_integration_test.py\",\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    tags = [\n        \"cuda\",\n        \"gpu\",\n        \"multi_and_single_gpu\",\n        \"requires-gpu-nvidia\",\n        \"tf_integration_test\",\n    ],\n    deps = [\n        
\":estimator_py\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/api/BUILD",
    "content": "# Placeholder: load aliased py_binary\nload(\"//tensorflow_estimator/python/estimator/api:api_gen.bzl\", \"ESTIMATOR_API_INIT_FILES_V1\", \"ESTIMATOR_API_INIT_FILES_V2\", \"generate_apis\")\n\npackage(default_visibility = [\"//tensorflow_estimator:internal\"])\n\nlicenses([\"notice\"])\n\n# This flag specifies whether Estimator 2.0 API should be built instead\n# of 1.* API. Note that Estimator 2.0 API is currently under development.\nconfig_setting(\n    name = \"api_version_2\",\n    define_values = {\"estimator_api_version\": \"2\"},\n)\n\npy_binary(\n    name = \"extractor_wrapper\",\n    srcs = [\"extractor_wrapper.py\"],\n    visibility = [\"//visibility:public\"],\n    deps = [\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",  # absl:app\n    ],\n)\n\npy_binary(\n    name = \"generator_wrapper\",\n    srcs = [\"generator_wrapper.py\"],\n    visibility = [\"//visibility:public\"],\n    deps = [\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",  # absl:app\n    ],\n)\n\ngenrule(\n    name = \"estimator_python_api_gen\",\n    srcs = select({\n        \"api_version_2\": [\"_v2/v2.py\"],\n        \"//conditions:default\": [\"_v1/v1.py\"],\n    }),\n    outs = [\"__init__.py\"],\n    cmd = select({\n        \"api_version_2\": \"cp $(location :_v2/v2.py) $(OUTS)\",\n        \"//conditions:default\": \"cp $(location :_v1/v1.py) $(OUTS)\",\n    }),\n)\n\ngenerate_apis(\n    name = \"estimator_python_api_gen_compat_v1\",\n    api_version = 1,\n    output_dir = \"_v1/\",\n    output_files = ESTIMATOR_API_INIT_FILES_V1,\n    output_package = \"tensorflow_estimator.python.estimator.api._v1\",\n    root_file_name = \"v1.py\",\n    visibility = [\"//visibility:public\"],\n)\n\ngenerate_apis(\n    name = \"estimator_python_api_gen_compat_v2\",\n    api_version = 2,\n    output_dir = \"_v2/\",\n    output_files = ESTIMATOR_API_INIT_FILES_V2,\n    output_package = 
\"tensorflow_estimator.python.estimator.api._v2\",\n    root_file_name = \"v2.py\",\n    visibility = [\"//visibility:public\"],\n)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/api/api_gen.bzl",
    "content": "\"\"\"Targets for generating TensorFlow Estimator Python API __init__.py files.\n\nThis bzl file is copied with slight modifications from\ntensorflow/python/tools/api/generator2/generate_api.bzl\nso that we can avoid needing to depend on TF source code in Bazel build.\n\nIt should be noted that because this file is executed during the build,\nand it imports TensorFlow code, that installing TensorFlow python package\nis required to Bazel build Estimator.\n\"\"\"\n\n# Placeholder: load PyInfo\nload(\"//tensorflow_estimator:estimator.bzl\", \"if_indexing_source_code\")\n\n_TARGET_PATTERNS = [\n    \"//tensorflow_estimator:\",\n    \"//tensorflow_estimator/\",\n]\n\n_DECORATOR = \"tensorflow_estimator.python.estimator.estimator_export.estimator_export\"\n\n_MODULE_PREFIX = \"\"\n\nESTIMATOR_API_INIT_FILES_V1 = [\n    \"__init__.py\",\n    \"estimator/__init__.py\",\n    \"estimator/experimental/__init__.py\",\n    \"estimator/export/__init__.py\",\n    \"estimator/inputs/__init__.py\",\n    \"estimator/tpu/__init__.py\",\n    \"estimator/tpu/experimental/__init__.py\",\n]\n\nESTIMATOR_API_INIT_FILES_V2 = [\n    \"__init__.py\",\n    \"estimator/__init__.py\",\n    \"estimator/experimental/__init__.py\",\n    \"estimator/export/__init__.py\",\n    \"estimator/inputs/__init__.py\",\n]\n\ndef _any_match(label):\n    full_target = \"//\" + label.package + \":\" + label.name\n    for pattern in _TARGET_PATTERNS:\n        if pattern in full_target:\n            return True\n    return False\n\ndef _join(path, *others):\n    result = path\n\n    for p in others:\n        if not result or result.endswith(\"/\"):\n            result += p\n        else:\n            result += \"/\" + p\n\n    return result\n\ndef _api_info_init(*, transitive_api):\n    if type(transitive_api) != type(depset()):\n        fail(\"ApiInfo.transitive_api must be a depset\")\n    return {\"transitive_api\": transitive_api}\n\nApiInfo, _new_api_info = provider(\n    doc = \"Provider for 
API symbols and docstrings extracted from Python files.\",\n    fields = {\n        \"transitive_api\": \"depset of files with extracted API.\",\n    },\n    init = _api_info_init,\n)\n\ndef _py_files(f):\n    if f.basename.endswith(\".py\") or f.basename.endswith(\".py3\"):\n        return f.path\n    return None\n\ndef _merge_py_info(\n        deps,\n        direct_sources = None,\n        direct_imports = None,\n        has_py2_only_sources = False,\n        has_py3_only_sources = False,\n        uses_shared_libraries = False):\n    transitive_sources = []\n    transitive_imports = []\n    for dep in deps:\n        if PyInfo in dep:\n            transitive_sources.append(dep[PyInfo].transitive_sources)\n            transitive_imports.append(dep[PyInfo].imports)\n            has_py2_only_sources = has_py2_only_sources or dep[PyInfo].has_py2_only_sources\n            has_py3_only_sources = has_py3_only_sources or dep[PyInfo].has_py3_only_sources\n            uses_shared_libraries = uses_shared_libraries or dep[PyInfo].uses_shared_libraries\n\n    return PyInfo(\n        transitive_sources = depset(direct = direct_sources, transitive = transitive_sources),\n        imports = depset(direct = direct_imports, transitive = transitive_imports),\n        has_py2_only_sources = has_py2_only_sources,\n        has_py3_only_sources = has_py3_only_sources,\n        uses_shared_libraries = uses_shared_libraries,\n    )\n\ndef _merge_api_info(\n        deps,\n        direct_api = None):\n    transitive_api = []\n    for dep in deps:\n        if ApiInfo in dep:\n            transitive_api.append(dep[ApiInfo].transitive_api)\n    return ApiInfo(transitive_api = depset(direct = direct_api, transitive = transitive_api))\n\ndef _api_extractor_impl(target, ctx):\n    direct_api = []\n\n    # Make sure the rule has a non-empty srcs attribute.\n    if (\n        _any_match(target.label) and\n        hasattr(ctx.rule.attr, \"srcs\") and\n        ctx.rule.attr.srcs\n    ):\n        
output = ctx.actions.declare_file(\"_\".join([\n            target.label.name,\n            \"extracted_tensorflow_estimator_api.json\",\n        ]))\n\n        args = ctx.actions.args()\n        args.set_param_file_format(\"multiline\")\n        args.use_param_file(\"--flagfile=%s\")\n\n        args.add(\"--output\", output)\n        args.add(\"--decorator\", _DECORATOR)\n        args.add(\"--api_name\", \"tensorflow_estimator\")\n        args.add_all(ctx.rule.files.srcs, expand_directories = True, map_each = _py_files)\n\n        ctx.actions.run(\n            mnemonic = \"ExtractAPI\",\n            executable = ctx.executable._extractor_bin,\n            inputs = ctx.rule.files.srcs,\n            outputs = [output],\n            arguments = [args],\n            progress_message = \"Extracting tensorflow_estimator APIs for %{label} to %{output}.\",\n        )\n\n        direct_api.append(output)\n\n    return [\n        _merge_api_info(ctx.rule.attr.deps if hasattr(ctx.rule.attr, \"deps\") else [], direct_api = direct_api),\n    ]\n\napi_extractor = aspect(\n    doc = \"Extracts the exported API for the given target and its dependencies.\",\n    implementation = _api_extractor_impl,\n    attr_aspects = [\"deps\"],\n    provides = [ApiInfo],\n    # Currently the Python rules do not correctly advertise their providers.\n    # required_providers = [PyInfo],\n    attrs = {\n        \"_extractor_bin\": attr.label(\n            default = Label(\"//tensorflow_estimator/python/estimator/api:extractor_wrapper\"),\n            executable = True,\n            cfg = \"exec\",\n        ),\n    },\n)\n\ndef _extract_api_impl(ctx):\n    return [\n        _merge_api_info(ctx.attr.deps),\n        _merge_py_info(ctx.attr.deps),\n    ]\n\nextract_api = rule(\n    doc = \"Extract Python API for all targets in transitive dependencies.\",\n    implementation = _extract_api_impl,\n    attrs = {\n        \"deps\": attr.label_list(\n            doc = \"Targets to extract API from.\",\n    
        allow_empty = False,\n            aspects = [api_extractor],\n            providers = [PyInfo],\n            mandatory = True,\n        ),\n    },\n    provides = [ApiInfo, PyInfo],\n)\n\ndef _generate_api_impl(ctx):\n    args = ctx.actions.args()\n    args.set_param_file_format(\"multiline\")\n    args.use_param_file(\"--flagfile=%s\")\n\n    args.add_joined(\"--output_files\", ctx.outputs.output_files, join_with = \",\")\n    args.add(\"--output_dir\", _join(ctx.bin_dir.path, ctx.label.package, ctx.attr.output_dir))\n    if ctx.file.root_init_template:\n        args.add(\"--root_init_template\", ctx.file.root_init_template)\n    args.add(\"--apiversion\", ctx.attr.api_version)\n    args.add_joined(\"--compat_api_versions\", ctx.attr.compat_api_versions, join_with = \",\")\n    args.add_joined(\"--compat_init_templates\", ctx.files.compat_init_templates, join_with = \",\")\n    args.add(\"--output_package\", ctx.attr.output_package)\n    args.add_joined(\"--packages_to_ignore\", ctx.attr.packages_to_ignore, join_with = \",\")\n    if _MODULE_PREFIX:\n        args.add(\"--module_prefix\", _MODULE_PREFIX)\n    if ctx.attr.use_lazy_loading:\n        args.add(\"--use_lazy_loading\")\n    else:\n        args.add(\"--nouse_lazy_loading\")\n    if ctx.attr.proxy_module_root:\n        args.add(\"--proxy_module_root\", ctx.attr.proxy_module_root)\n    args.add_joined(\"--file_prefixes_to_strip\", [ctx.bin_dir.path, ctx.genfiles_dir.path], join_with = \",\")\n    if ctx.attr.root_file_name:\n        args.add(\"--root_file_name\", ctx.attr.root_file_name)\n\n    inputs = depset(transitive = [\n        dep[ApiInfo].transitive_api\n        for dep in ctx.attr.deps\n    ])\n    args.add_all(\n        inputs,\n        expand_directories = True,\n    )\n\n    transitive_inputs = [inputs]\n    if ctx.attr.root_init_template:\n        transitive_inputs.append(ctx.attr.root_init_template.files)\n\n    ctx.actions.run(\n        mnemonic = \"GenerateAPI\",\n        executable 
= ctx.executable._generator_bin,\n        inputs = depset(\n            direct = ctx.files.compat_init_templates,\n            transitive = transitive_inputs,\n        ),\n        outputs = ctx.outputs.output_files,\n        arguments = [args],\n        progress_message = \"Generating APIs for %{label} to %{output}.\",\n    )\n\ngenerate_api = rule(\n    doc = \"Generate Python API for all targets in transitive dependencies.\",\n    implementation = _generate_api_impl,\n    attrs = {\n        \"deps\": attr.label_list(\n            doc = \"extract_api targets to generate API from.\",\n            allow_empty = True,\n            providers = [ApiInfo, PyInfo],\n            mandatory = True,\n        ),\n        \"root_init_template\": attr.label(\n            doc = \"Template for the top level __init__.py file\",\n            allow_single_file = True,\n        ),\n        \"api_version\": attr.int(\n            doc = \"The API version to generate (1 or 2)\",\n            values = [1, 2],\n        ),\n        \"compat_api_versions\": attr.int_list(\n            doc = \"Additional versions to generate in compat/ subdirectory.\",\n        ),\n        \"compat_init_templates\": attr.label_list(\n            doc = \"Template for top-level __init__ files under compat modules. This list must be \" +\n                  \"in the same order as the list of versions in compat_api_versions\",\n            allow_files = True,\n        ),\n        \"output_package\": attr.string(\n            doc = \"Root output package.\",\n        ),\n        \"output_dir\": attr.string(\n            doc = \"Subdirectory to output API to. If non-empty, must end with '/'.\",\n        ),\n        \"proxy_module_root\": attr.string(\n            doc = \"Module root for proxy-import format. 
If specified, proxy files with \" +\n                  \"`from proxy_module_root.proxy_module import *` will be created to enable \" +\n                  \"import resolution under TensorFlow.\",\n        ),\n        \"output_files\": attr.output_list(\n            doc = \"List of __init__.py files that should be generated. This list should include \" +\n                  \"file name for every module exported using tf_export. For e.g. if an op is \" +\n                  \"decorated with @tf_export('module1.module2', 'module3'). Then, output_files \" +\n                  \"should include module1/module2/__init__.py and module3/__init__.py.\",\n        ),\n        \"use_lazy_loading\": attr.bool(\n            doc = \"If true, lazy load imports in the generated API rather than importing them all statically.\",\n        ),\n        \"packages_to_ignore\": attr.string_list(\n            doc = \"List of packages to ignore tf_exports from.\",\n        ),\n        \"root_file_name\": attr.string(\n            doc = \"The file name that should be generated for the top level API.\",\n        ),\n        \"_generator_bin\": attr.label(\n            default = Label(\"//tensorflow_estimator/python/estimator/api:generator_wrapper\"),\n            executable = True,\n            cfg = \"exec\",\n        ),\n    },\n)\n\ndef generate_apis(\n        name,\n        deps = [\n            \"//tensorflow_estimator/python/estimator:estimator_py\",\n            # \"//third_party/tensorflow/lite/python:analyzer\",\n            # \"//third_party/tensorflow/lite/python:lite\",\n            # \"//third_party/tensorflow/lite/python/authoring\",\n        ],\n        output_files = ESTIMATOR_API_INIT_FILES_V2,\n        root_init_template = None,\n        api_version = 2,\n        compat_api_versions = [],\n        compat_init_templates = [],\n        output_package = \"tensorflow_estimator.python.estimator.api\",\n        output_dir = \"\",\n        proxy_module_root = None,\n        
packages_to_ignore = [],\n        root_file_name = \"__init__.py\",\n        visibility = [\"//visibility:private\"]):\n    \"\"\"Generate TensorFlow APIs for a set of libraries.\n\n    Args:\n        name: name of generate_api target.\n        deps: python_library targets to serve as roots for extracting APIs.\n        output_files: The list of files that the API generator is expected to create.\n        root_init_template: The template for the top level __init__.py file generated.\n            \"#API IMPORTS PLACEHOLDER\" comment will be replaced with imports.\n        api_version: The API version to generate. (1 or 2)\n        compat_api_versions: Additional versions to generate in compat/ subdirectory.\n        compat_init_templates: Template for top level __init__.py files under the compat modules.\n            The list must be in the same order as the list of versions in 'compat_api_versions'\n        output_package: Root output package.\n        output_dir: Directory where the generated output files are placed. This should be a prefix\n            of every directory in 'output_files'\n        proxy_module_root: Module root for proxy-import format. 
If specified, proxy files with\n            `from proxy_module_root.proxy_module import *` will be created to enable import\n            resolution under TensorFlow.\n        packages_to_ignore: List of packages to ignore tf_exports from.\n        root_file_name: The file name that should be generated for the top level API.\n        visibility: Visibility of the target containing the generated files.\n    \"\"\"\n    extract_name = name + \".extract-tensorflow-estimator\"\n    extract_api(\n        name = extract_name,\n        deps = deps,\n        visibility = [\"//visibility:private\"],\n    )\n\n    if proxy_module_root != None:\n        # Avoid conflicts between the __init__.py file of TensorFlow and proxy module.\n        output_files = [f for f in output_files if f != \"__init__.py\"]\n\n    if root_file_name != None:\n        output_files = [f if f != \"__init__.py\" else root_file_name for f in output_files]\n\n    all_output_files = [_join(output_dir, f) for f in output_files]\n\n    generate_api(\n        name = name,\n        deps = [\":\" + extract_name],\n        output_files = all_output_files,\n        output_dir = output_dir,\n        root_init_template = root_init_template,\n        compat_api_versions = compat_api_versions,\n        compat_init_templates = compat_init_templates,\n        api_version = api_version,\n        proxy_module_root = proxy_module_root,\n        visibility = visibility,\n        packages_to_ignore = packages_to_ignore,\n        use_lazy_loading = False,\n        output_package = output_package,\n        root_file_name = root_file_name,\n    )\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/api/extractor_wrapper.py",
    "content": "# Copyright 2023 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Thin wrapper to call TensorFlow's API extractor script.\"\"\"\nfrom absl import app\n\nfrom tensorflow.python.tools.api.generator2.extractor import extractor\n\nif __name__ == \"__main__\":\n  app.run(extractor.main)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/api/generator_wrapper.py",
    "content": "# Copyright 2023 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Thin wrapper to call TensorFlow's API generator script.\"\"\"\nfrom absl import app\nfrom tensorflow.python.tools.api.generator2.generator import generator\n\nif __name__ == \"__main__\":\n  app.run(generator.main)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/__init__.py",
    "content": ""
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/baseline.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Baseline estimators.\n\nBaseline estimators are bias-only estimators that can be used for debugging\nand as simple baselines.\n\nExample:\n\n```\n# Build BaselineClassifier\nclassifier = BaselineClassifier(n_classes=3)\n\n# Input builders\ndef input_fn_train():\n  # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n  # index.\n  pass\n\ndef input_fn_eval():\n  # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n  # index.\n  pass\n\n# Fit model.\nclassifier.train(input_fn=input_fn_train)\n\n# Evaluate cross entropy between the test and train labels.\nloss = classifier.evaluate(input_fn=input_fn_eval)[\"loss\"]\n\n# predict outputs the probability distribution of the classes as seen in\n# training.\npredictions = classifier.predict(new_samples)\n```\n\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.feature_column import feature_column as feature_column_v1\nfrom tensorflow.python.feature_column import feature_column_v2\nfrom tensorflow.python.framework import ops\nfrom tensorflow_estimator.python.estimator import estimator\nfrom tensorflow_estimator.python.estimator.canned 
import head as head_lib\nfrom tensorflow_estimator.python.estimator.canned import optimizers\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.head import head_utils\nfrom tensorflow_estimator.python.estimator.head import regression_head\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n# The default learning rate of 0.3 is a historical artifact of the initial\n# implementation, but seems a reasonable choice.\n_LEARNING_RATE = 0.3\n\n\ndef _get_weight_column_key(weight_column):\n  if weight_column is None:\n    return None\n  if isinstance(weight_column, six.string_types):\n    return weight_column\n  if not isinstance(weight_column, feature_column_v1._NumericColumn):  # pylint: disable=protected-access\n    raise TypeError('Weight column must be either a string or _NumericColumn.'\n                    ' Given type: {}.'.format(type(weight_column)))\n  return weight_column.key()\n\n\ndef _get_weight_column_key_v2(weight_column):\n  if weight_column is None:\n    return None\n  if isinstance(weight_column, six.string_types):\n    return weight_column\n  if not isinstance(weight_column, feature_column_v2.NumericColumn):\n    raise TypeError('Weight column must be either a string or NumericColumn. 
'\n                    'Given type: {}.'.format(type(weight_column)))\n  return weight_column.key()\n\n\ndef _get_batch_size_and_size_checks(features, weight_column_key):\n  \"\"\"Returns batch_size and size_checks.\"\"\"\n  size_checks = []\n  batch_size = None\n\n  # The first dimension is assumed to be a batch size and must be consistent\n  # among all of the features.\n  for key, feature in features.items():\n    # Skip weight_column to ensure we don't add size checks to it.\n    # These would introduce a dependency on the weight at serving time.\n    if key == weight_column_key:\n      continue\n    first_dim = tf.compat.v1.shape(feature)[0]\n    if batch_size is None:\n      batch_size = first_dim\n    else:\n      size_checks.append(\n          tf.compat.v1.debugging.assert_equal(batch_size, first_dim))\n\n  return size_checks, batch_size\n\n\ndef _baseline_logit_fn_builder(num_outputs, weight_column=None):\n  \"\"\"Function builder for a baseline logit_fn.\n\n  Args:\n    num_outputs: Number of outputs for the model.\n    weight_column: A string or a `_NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It will be multiplied by the loss of the example.\n\n  Returns:\n    A logit_fn (see below).\n  \"\"\"\n\n  def baseline_logit_fn(features):\n    \"\"\"Baseline model logit_fn.\n\n    The baseline model simply learns a bias, so the output logits are a\n    `Variable` with one weight for each output that learns the bias for the\n    corresponding output.\n\n    Args:\n      features: The first item returned from the `input_fn` passed to `train`,\n        `evaluate`, and `predict`. 
This should be a single `Tensor` or dict with\n        `Tensor` values.\n\n    Returns:\n      A `Tensor` representing the logits.\n    \"\"\"\n    weight_column_key = _get_weight_column_key(weight_column)\n    size_checks, batch_size = _get_batch_size_and_size_checks(\n        features, weight_column_key)\n    with tf.control_dependencies(size_checks):\n      with tf.compat.v1.variable_scope('baseline'):\n        bias = tf.compat.v1.get_variable(\n            'bias',\n            shape=[num_outputs],\n            initializer=tf.compat.v1.initializers.zeros)\n        return tf.math.multiply(bias, tf.ones([batch_size, num_outputs]))\n\n  return baseline_logit_fn\n\n\ndef _baseline_model_fn(features,\n                       labels,\n                       mode,\n                       head,\n                       optimizer,\n                       weight_column=None,\n                       config=None):\n  \"\"\"Model_fn for baseline models.\n\n  Args:\n    features: `Tensor` or dict of `Tensor` (depends on data passed to `train`).\n    labels: `Tensor` of labels that are compatible with the `Head` instance.\n    mode: Defines whether this is training, evaluation or prediction. See\n      `ModeKeys`.\n    head: A `Head` instance.\n    optimizer: String, `tf.Optimizer` object, or callable that creates the\n      optimizer to use for training. If not specified, will use `FtrlOptimizer`\n      with a default learning rate of 0.3.\n    weight_column: A string or a `_NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. 
It will be multiplied by the loss of the example.\n    config: `RunConfig` object to configure the runtime settings.\n\n  Raises:\n    KeyError: If weight column is specified but not present.\n    ValueError: If features is an empty dictionary.\n\n  Returns:\n    An `EstimatorSpec` instance.\n  \"\"\"\n  del config  # Unused.\n\n  logit_fn = _baseline_logit_fn_builder(head.logits_dimension, weight_column)\n  logits = logit_fn(features)\n\n  def train_op_fn(loss):\n    opt = optimizers.get_optimizer_instance(\n        optimizer, learning_rate=_LEARNING_RATE)\n    return opt.minimize(loss, global_step=tf.compat.v1.train.get_global_step())\n\n  return head.create_estimator_spec(\n      features=features,\n      mode=mode,\n      logits=logits,\n      labels=labels,\n      train_op_fn=train_op_fn)\n\n\ndef _baseline_model_fn_builder_v2(features, num_outputs, weight_column=None):\n  \"\"\"Function builder for a baseline logit_fn.\n\n  Args:\n    features: The first item returned from the `input_fn` passed to `train`,\n      `evaluate`, and `predict`. This should be a single `Tensor` or dict with\n      `Tensor` values.\n    num_outputs: Number of outputs for the model.\n    weight_column: A string or a `NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. 
It will be multiplied by the loss of the example.\n\n  Returns:\n    A list of trainable variables and a `Tensor` representing the logits.\n  \"\"\"\n  weight_column_key = _get_weight_column_key_v2(weight_column)\n  size_checks, batch_size = _get_batch_size_and_size_checks(\n      features, weight_column_key)\n  with tf.control_dependencies(size_checks):\n    with ops.name_scope('baseline'):\n      bias = tf.Variable(initial_value=tf.zeros([num_outputs]), name='bias')\n      logits = tf.math.multiply(bias, tf.ones([batch_size, num_outputs]))\n  return [bias], logits\n\n\ndef _baseline_model_fn_v2(\n    features,\n    labels,\n    mode,\n    head,\n    optimizer,\n    weight_column=None,\n    config=None,\n    loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE):\n  \"\"\"Model_fn for baseline models.\n\n  Args:\n    features: `Tensor` or dict of `Tensor` (depends on data passed to `train`).\n    labels: `Tensor` of labels that are compatible with the `Head` instance.\n    mode: Defines whether this is training, evaluation or prediction. See\n      `ModeKeys`.\n    head: A `Head` instance.\n    optimizer: String, `tf.Optimizer` object, or callable that creates the\n      optimizer to use for training. If not specified, will use `FtrlOptimizer`\n      with a default learning rate of 0.3.\n    weight_column: A string or a `NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It will be multiplied by the loss of the example.\n    config: `RunConfig` object to configure the runtime settings.\n    loss_reduction: One of `tf_keras.losses.Reduction` except `NONE`. Describes\n      how to reduce training loss over batch. 
Defaults to `SUM_OVER_BATCH_SIZE`.\n\n  Raises:\n    KeyError: If weight column is specified but not present.\n    ValueError: If features is an empty dictionary.\n\n  Returns:\n    An `EstimatorSpec` instance.\n  \"\"\"\n  del config  # Unused.\n\n  trainable_variables, logits = _baseline_model_fn_builder_v2(\n      features, head.logits_dimension, weight_column)\n\n  # In TRAIN mode, create optimizer and assign global_step variable to\n  # optimizer.iterations to make global_step increased correctly, as Hooks\n  # relies on global step as step counter.\n  if mode == ModeKeys.TRAIN:\n    opt = optimizers.get_optimizer_instance_v2(\n        optimizer, learning_rate=_LEARNING_RATE)\n    opt.iterations = tf.compat.v1.train.get_or_create_global_step()\n\n  def train_op_fn(loss):\n    # Scale loss by number of replicas.\n    if loss_reduction == tf.losses.Reduction.SUM_OVER_BATCH_SIZE:\n      num_replicas = tf.distribute.get_strategy().num_replicas_in_sync\n      if num_replicas > 1:\n        loss *= (1. / num_replicas)\n    return opt.get_updates(loss, trainable_variables)[0]\n\n  return head.create_estimator_spec(\n      features=features,\n      mode=mode,\n      logits=logits,\n      labels=labels,\n      train_op_fn=train_op_fn)\n\n\n@estimator_export('estimator.BaselineClassifier', v1=[])\nclass BaselineClassifierV2(estimator.EstimatorV2):\n  \"\"\"A classifier that can establish a simple baseline.\n\n  This classifier ignores feature values and will learn to predict the average\n  value of each label. For single-label problems, this will predict the\n  probability distribution of the classes as seen in the labels. 
For multi-label\n  problems, this will predict the fraction of examples that are positive for\n  each class.\n\n  Example:\n\n  ```python\n\n  # Build BaselineClassifier\n  classifier = tf.estimator.BaselineClassifier(n_classes=3)\n\n  # Input builders\n  def input_fn_train:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n\n  def input_fn_eval:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n\n  # Fit model.\n  classifier.train(input_fn=input_fn_train)\n\n  # Evaluate cross entropy between the test and train labels.\n  loss = classifier.evaluate(input_fn=input_fn_eval)[\"loss\"]\n\n  # predict outputs the probability distribution of the classes as seen in\n  # training.\n  predictions = classifier.predict(new_samples)\n\n  ```\n\n  Input of `train` and `evaluate` should have following features,\n    otherwise there will be a `KeyError`:\n\n  * if `weight_column` is not `None`, a feature with\n     `key=weight_column` whose value is a `Tensor`.\n\n  @compatibility(eager)\n  Estimators can be used while eager execution is enabled. Note that `input_fn`\n  and all hooks are executed inside a graph context, so they have to be written\n  to be compatible with graph mode. Note that `input_fn` code using `tf.data`\n  generally works in both graph and eager modes.\n  @end_compatibility\n  \"\"\"\n\n  def __init__(self,\n               model_dir=None,\n               n_classes=2,\n               weight_column=None,\n               label_vocabulary=None,\n               optimizer='Ftrl',\n               config=None,\n               loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE):\n    \"\"\"Initializes a BaselineClassifier instance.\n\n    Args:\n      model_dir: Directory to save model parameters, graph and etc. 
This can\n        also be used to load checkpoints from the directory into a estimator to\n        continue training a previously saved model.\n      n_classes: number of label classes. Default is binary classification.\n        It must be greater than 1. Note: Class labels are integers representing\n          the class index (i.e. values from 0 to n_classes-1). For arbitrary\n          label values (e.g. string labels), convert to class indices first.\n      weight_column: A string or a `NumericColumn` created by\n        `tf.feature_column.numeric_column` defining feature column representing\n        weights. It will be multiplied by the loss of the example.\n      label_vocabulary: Optional list of strings with size `[n_classes]`\n        defining the label vocabulary. Only supported for `n_classes` > 2.\n      optimizer: String, `tf_keras.optimizers.*` object, or callable that\n        creates the optimizer to use for training. If not specified, will use\n        `Ftrl` as the default optimizer.\n      config: `RunConfig` object to configure the runtime settings.\n      loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how\n        to reduce training loss over batch. 
Defaults to `SUM_OVER_BATCH_SIZE`.\n\n    Returns:\n      A `BaselineClassifier` estimator.\n\n    Raises:\n      ValueError: If `n_classes` < 2.\n    \"\"\"\n    head = head_utils.binary_or_multi_class_head(\n        n_classes,\n        weight_column=weight_column,\n        label_vocabulary=label_vocabulary,\n        loss_reduction=loss_reduction)\n\n    def _model_fn(features, labels, mode, config):\n      return _baseline_model_fn_v2(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          optimizer=optimizer,\n          weight_column=weight_column,\n          config=config,\n          loss_reduction=loss_reduction)\n\n    super(BaselineClassifierV2, self).__init__(\n        model_fn=_model_fn, model_dir=model_dir, config=config)\n\n\n@estimator_export(v1=['estimator.BaselineClassifier'])  # pylint: disable=missing-docstring\nclass BaselineClassifier(estimator.Estimator):\n  __doc__ = BaselineClassifierV2.__doc__.replace('SUM_OVER_BATCH_SIZE', 'SUM')\n\n  def __init__(self,\n               model_dir=None,\n               n_classes=2,\n               weight_column=None,\n               label_vocabulary=None,\n               optimizer='Ftrl',\n               config=None,\n               loss_reduction=tf.compat.v1.losses.Reduction.SUM):\n    head = head_lib._binary_logistic_or_multi_class_head(  # pylint: disable=protected-access\n        n_classes, weight_column, label_vocabulary, loss_reduction)\n\n    def _model_fn(features, labels, mode, config):\n      return _baseline_model_fn(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          optimizer=optimizer,\n          weight_column=weight_column,\n          config=config)\n\n    super(BaselineClassifier, self).__init__(\n        model_fn=_model_fn, model_dir=model_dir, config=config)\n\n\n@estimator_export('estimator.BaselineEstimator', v1=[])\nclass BaselineEstimatorV2(estimator.EstimatorV2):\n  
\"\"\"An estimator that can establish a simple baseline.\n\n  The estimator uses a user-specified head.\n\n  This estimator ignores feature values and will learn to predict the average\n  value of each label. E.g. for single-label classification problems, this will\n  predict the probability distribution of the classes as seen in the labels.\n  For multi-label classification problems, it will predict the ratio of examples\n  that contain each class.\n\n  Example:\n\n  ```python\n\n  # Build baseline multi-label classifier.\n  estimator = tf.estimator.BaselineEstimator(\n      head=tf.estimator.MultiLabelHead(n_classes=3))\n\n  # Input builders\n  def input_fn_train:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n\n  def input_fn_eval:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n\n  # Fit model.\n  estimator.train(input_fn=input_fn_train)\n\n  # Evaluates cross entropy between the test and train labels.\n  loss = estimator.evaluate(input_fn=input_fn_eval)[\"loss\"]\n\n  # For each class, predicts the ratio of training examples that contain the\n  # class.\n  predictions = estimator.predict(new_samples)\n\n  ```\n\n  Input of `train` and `evaluate` should have following features,\n    otherwise there will be a `KeyError`:\n\n  * if `weight_column` is specified in the `head` constructor (and not None) for\n    the head passed to BaselineEstimator's constructor, a feature with\n    `key=weight_column` whose value is a `Tensor`.\n  \"\"\"\n\n  def __init__(self, head, model_dir=None, optimizer='Ftrl', config=None):\n    \"\"\"Initializes a BaselineEstimator instance.\n\n    Args:\n      head: A `Head` instance constructed with a method such as\n        `tf.estimator.MultiLabelHead`.\n      model_dir: Directory to save model parameters, graph and etc. 
This can\n        also be used to load checkpoints from the directory into a estimator to\n        continue training a previously saved model.\n      optimizer: String, `tf_keras.optimizers.*` object, or callable that\n        creates the optimizer to use for training. If not specified, will use\n        `Ftrl` as the default optimizer.\n      config: `RunConfig` object to configure the runtime settings.\n    \"\"\"\n\n    def _model_fn(features, labels, mode, config):\n      return _baseline_model_fn_v2(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          optimizer=optimizer,\n          config=config)\n\n    super(BaselineEstimatorV2, self).__init__(\n        model_fn=_model_fn, model_dir=model_dir, config=config)\n\n\n@estimator_export(v1=['estimator.BaselineEstimator'])  # pylint: disable=missing-docstring\nclass BaselineEstimator(estimator.Estimator):\n  __doc__ = BaselineEstimatorV2.__doc__\n\n  def __init__(self, head, model_dir=None, optimizer='Ftrl', config=None):\n\n    def _model_fn(features, labels, mode, config):\n      return _baseline_model_fn(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          optimizer=optimizer,\n          config=config)\n\n    super(BaselineEstimator, self).__init__(\n        model_fn=_model_fn, model_dir=model_dir, config=config)\n\n\n@estimator_export('estimator.BaselineRegressor', v1=[])\nclass BaselineRegressorV2(estimator.EstimatorV2):\n  \"\"\"A regressor that can establish a simple baseline.\n\n  This regressor ignores feature values and will learn to predict the average\n  value of each label.\n\n  Example:\n\n  ```python\n\n  # Build BaselineRegressor\n  regressor = tf.estimator.BaselineRegressor()\n\n  # Input builders\n  def input_fn_train:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n\n  def input_fn_eval:\n    # Returns tf.data.Dataset of (x, 
y) tuple where y represents label's class\n    # index.\n    pass\n\n  # Fit model.\n  regressor.train(input_fn=input_fn_train)\n\n  # Evaluate squared-loss between the test and train targets.\n  loss = regressor.evaluate(input_fn=input_fn_eval)[\"loss\"]\n\n  # predict outputs the mean value seen during training.\n  predictions = regressor.predict(new_samples)\n  ```\n\n  Input of `train` and `evaluate` should have following features,\n    otherwise there will be a `KeyError`:\n\n  * if `weight_column` is not `None`, a feature with\n     `key=weight_column` whose value is a `Tensor`.\n\n  @compatibility(eager)\n  Estimators can be used while eager execution is enabled. Note that `input_fn`\n  and all hooks are executed inside a graph context, so they have to be written\n  to be compatible with graph mode. Note that `input_fn` code using `tf.data`\n  generally works in both graph and eager modes.\n  @end_compatibility\n  \"\"\"\n\n  def __init__(self,\n               model_dir=None,\n               label_dimension=1,\n               weight_column=None,\n               optimizer='Ftrl',\n               config=None,\n               loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE):\n    \"\"\"Initializes a BaselineRegressor instance.\n\n    Args:\n      model_dir: Directory to save model parameters, graph and etc. This can\n        also be used to load checkpoints from the directory into a estimator to\n        continue training a previously saved model.\n      label_dimension: Number of regression targets per example. This is the\n        size of the last dimension of the labels and logits `Tensor` objects\n        (typically, these have shape `[batch_size, label_dimension]`).\n      weight_column: A string or a `_NumericColumn` created by\n        `tf.feature_column.numeric_column` defining feature column representing\n        weights. 
It will be multiplied by the loss of the example.\n      optimizer: String, `tf_keras.optimizers.*` object, or callable that\n        creates the optimizer to use for training. If not specified, will use\n        `Ftrl` as the default optimizer.\n      config: `RunConfig` object to configure the runtime settings.\n      loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how\n        to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.\n\n    Returns:\n      A `BaselineRegressor` estimator.\n    \"\"\"\n    head = regression_head.RegressionHead(\n        label_dimension=label_dimension,\n        weight_column=weight_column,\n        loss_reduction=loss_reduction)\n\n    def _model_fn(features, labels, mode, config):\n      return _baseline_model_fn_v2(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          optimizer=optimizer,\n          config=config)\n\n    super(BaselineRegressorV2, self).__init__(\n        model_fn=_model_fn, model_dir=model_dir, config=config)\n\n\n@estimator_export(v1=['estimator.BaselineRegressor'])  # pylint: disable=missing-docstring\nclass BaselineRegressor(estimator.Estimator):\n  __doc__ = BaselineRegressorV2.__doc__.replace('SUM_OVER_BATCH_SIZE', 'SUM')\n\n  def __init__(self,\n               model_dir=None,\n               label_dimension=1,\n               weight_column=None,\n               optimizer='Ftrl',\n               config=None,\n               loss_reduction=tf.compat.v1.losses.Reduction.SUM):\n    head = head_lib._regression_head(  # pylint: disable=protected-access\n        label_dimension=label_dimension,\n        weight_column=weight_column,\n        loss_reduction=loss_reduction)\n\n    def _model_fn(features, labels, mode, config):\n      return _baseline_model_fn(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          optimizer=optimizer,\n          config=config)\n\n    
super(BaselineRegressor, self).__init__(\n        model_fn=_model_fn, model_dir=model_dir, config=config)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/baseline_estimator_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for BaselineEstimator.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport shutil\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import baseline\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.head import regression_head\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\n\n# Names of variables created by model.\nBIAS_NAME = 'baseline/bias'\n\n\ndef assert_close(expected, actual, rtol=1e-04, name='assert_close'):\n  with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope:\n    expected = ops.convert_to_tensor(expected, name='expected')\n    actual = ops.convert_to_tensor(actual, name='actual')\n    rdiff = tf.math.abs(expected - actual, 'diff') / tf.math.abs(expected)\n    rtol = ops.convert_to_tensor(rtol, name='rtol')\n    return tf.compat.v1.debugging.assert_less(\n        rdiff,\n        rtol,\n        
data=('Condition expected =~ actual did not hold element-wise:'\n              'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff,\n              'rtol = ', rtol,),\n        name=scope)\n\n\ndef save_variables_to_ckpt(model_dir):\n  init_all_op = [tf.compat.v1.initializers.global_variables()]\n  with tf.compat.v1.Session() as sess:\n    sess.run(init_all_op)\n    tf.compat.v1.train.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))\n\n\ndef _baseline_estimator_fn(weight_column=None, label_dimension=1, **kwargs):\n  return baseline.BaselineEstimatorV2(\n      head=regression_head.RegressionHead(\n          weight_column=weight_column,\n          label_dimension=label_dimension,\n          loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE),\n      **kwargs)\n\n\ndef mock_optimizer_v2(testcase, expected_loss=None):\n  \"\"\"Creates a mock optimizer to test the train method.\n\n  Args:\n    testcase: A TestCase instance.\n    expected_loss: If given, will assert the loss value.\n\n  Returns:\n    A mock Optimizer.\n  \"\"\"\n  expected_var_names = ['%s:0' % BIAS_NAME]\n\n  class _Optimizer(tf_keras.optimizers.legacy.Optimizer):\n\n    def get_updates(self, loss, params):\n      trainable_vars = params\n      testcase.assertItemsEqual(expected_var_names,\n                                [var.name for var in trainable_vars])\n\n      # Verify loss. 
We can't check the value directly, so we add an assert op.\n      testcase.assertEquals(0, loss.shape.ndims)\n      if expected_loss is None:\n        if self.iterations is not None:\n          return [self.iterations.assign_add(1).op]\n        return [tf.no_op()]\n      assert_loss = assert_close(\n          tf.cast(expected_loss, name='expected', dtype=tf.dtypes.float32),\n          loss,\n          name='assert_loss')\n      with tf.control_dependencies((assert_loss,)):\n        if self.iterations is not None:\n          return [self.iterations.assign_add(1).op]\n        return [tf.no_op()]\n\n    def get_config(self):\n      config = super(_Optimizer, self).get_config()\n      return config\n\n  optimizer = _Optimizer(name='my_optimizer')\n\n  return optimizer\n\n\nclass BaselineEstimatorEvaluationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_evaluation_batch(self):\n    \"\"\"Tests evaluation for batch_size==2.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([13.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir)\n    eval_metrics = baseline_estimator.evaluate(\n        input_fn=lambda: ({'age': ((1,), (1,))}, ((10.,), (10.,))), steps=1)\n\n    # Logit is bias = 13, while label is 10.\n    # Loss per example is 3**2 = 9.\n    # Training loss is the sum over batch size = (9 + 9) / 2 = 9\n    # Average loss is the average over batch = 9\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 9.,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            
metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def test_evaluation_weights(self):\n    \"\"\"Tests evaluation with weights.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([13.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    def _input_fn():\n      features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))}\n      labels = ((10.,), (10.,))\n      return features, labels\n\n    baseline_estimator = _baseline_estimator_fn(\n        weight_column='weights', model_dir=self._model_dir)\n    eval_metrics = baseline_estimator.evaluate(input_fn=_input_fn, steps=1)\n\n    # Logit is bias = 13, while label is 10.\n    # Loss per example is 3**2 = 9.\n    # Training loss is the weighted sum over batch size= (9 + 2*9) / 2 = 13.5\n    # average loss is the weighted average = 9 + 2*9 / (1 + 2) = 9\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 13.5,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def test_evaluation_for_multi_dimensions(self):\n    label_dim = 2\n    with tf.Graph().as_default():\n      tf.Variable([46.0, 58.0], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_estimator = _baseline_estimator_fn(\n        label_dimension=label_dim, model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': np.array([[2., 4., 5.]]),\n        },\n        y=np.array([[46., 58.]]),\n        batch_size=1,\n        num_epochs=None,\n        shuffle=False)\n    eval_metrics = 
baseline_estimator.evaluate(input_fn=input_fn, steps=1)\n\n    self.assertItemsEqual(\n        (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,\n         metric_keys.MetricKeys.PREDICTION_MEAN,\n         metric_keys.MetricKeys.LABEL_MEAN, tf.compat.v1.GraphKeys.GLOBAL_STEP),\n        eval_metrics.keys())\n\n    # Logit is bias which is [46, 58]\n    self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])\n\n\nclass BaselineEstimatorPredictTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_1d(self):\n    \"\"\"Tests predict when all variables are one-dimensional.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([.2], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': np.array([[2.]])},\n        y=None,\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    predictions = baseline_estimator.predict(input_fn=predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # score = bias = .2
\n    self.assertAllClose([[.2]], predicted_scores)\n\n  def testMultiDim(self):\n    \"\"\"Tests predict when all variables are multi-dimensional.\"\"\"\n    batch_size = 2\n    label_dimension = 3\n    with tf.Graph().as_default():\n      tf.Variable(  # shape=[label_dimension]\n          [.2, .4, .6], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_estimator = _baseline_estimator_fn(\n        label_dimension=label_dimension, model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        # x shape=[batch_size, x_dim]\n        x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])},\n        y=None,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n    predictions = baseline_estimator.predict(input_fn=predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # score = bias, shape=[batch_size, label_dimension]\n    self.assertAllClose([[0.2, 0.4, 0.6], [0.2, 0.4, 0.6]], predicted_scores)\n\n\nclass BaselineEstimatorIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, label_dimension, prediction_length):\n    feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    est = _baseline_estimator_fn(\n        label_dimension=label_dimension, model_dir=self._model_dir)\n\n    # TRAIN\n    # learn y = x\n    est.train(train_input_fn, steps=200)\n\n    # EVALUATE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(200, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn(metric_keys.MetricKeys.LOSS, 
six.iterkeys(scores))\n\n    # PREDICT\n    predictions = np.array(\n        [x['predictions'] for x in est.predict(predict_input_fn)])\n    self.assertAllEqual((prediction_length, label_dimension), predictions.shape)\n\n    # EXPORT\n    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    label_dimension = 2\n    input_dimension = label_dimension\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=None,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        label_dimension=label_dimension,\n        prediction_length=prediction_length)\n\n\nclass BaselineEstimatorTrainingTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      
shutil.rmtree(self._model_dir)\n\n  def _assert_checkpoint(self,\n                         label_dimension,\n                         expected_global_step,\n                         expected_bias=None):\n    shapes = {\n        name: shape\n        for (name, shape) in tf.train.list_variables(self._model_dir)\n    }\n\n    self.assertEqual([], shapes[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertEqual(\n        expected_global_step,\n        tf.train.load_variable(self._model_dir,\n                               tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n    self.assertEqual([label_dimension], shapes[BIAS_NAME])\n    if expected_bias is not None:\n      self.assertEqual(expected_bias,\n                       tf.train.load_variable(self._model_dir, BIAS_NAME))\n\n  def testFromScratch(self):\n    # Create BaselineEstimator.\n    label = 5.\n    age = 17\n    # loss = (logits - label)^2 = (0 - 5.)^2 = 25.\n    mock_optimizer = mock_optimizer_v2(self, expected_loss=25.)\n    baseline_estimator = _baseline_estimator_fn(\n        model_dir=self._model_dir, optimizer=mock_optimizer)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    baseline_estimator.train(\n        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)\n    self.assertEqual(\n        num_steps,\n        baseline_estimator.get_variable_value(mock_optimizer.iterations.name))\n    self._assert_checkpoint(\n        label_dimension=1, expected_global_step=num_steps, expected_bias=[0.])\n\n  def testFromCheckpoint(self):\n    # Create initial checkpoint.\n    bias = 7.0\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable([bias], name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # logits = bias = 7.\n    # loss = (logits - label)^2 = (7 - 5)^2 = 4\n    
mock_optimizer = mock_optimizer_v2(self, expected_loss=4.)\n    baseline_estimator = _baseline_estimator_fn(\n        model_dir=self._model_dir, optimizer=mock_optimizer)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    baseline_estimator.train(\n        input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps)\n    self.assertEqual(\n        initial_global_step + num_steps,\n        baseline_estimator.get_variable_value(mock_optimizer.iterations.name))\n    self._assert_checkpoint(\n        label_dimension=1,\n        expected_global_step=initial_global_step + num_steps,\n        expected_bias=[bias])\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/baseline_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for baseline.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\nimport os\nimport shutil\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.core.example import example_pb2\nfrom tensorflow.core.example import feature_pb2\nfrom tensorflow.python.framework import ops\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import baseline\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.inputs import pandas_io\n\ntry:\n  # pylint: disable=g-import-not-at-top\n  import pandas as pd\n  HAS_PANDAS = True\nexcept IOError:\n  # Pandas writes a temporary file during import. 
If it fails, don't use pandas.\n  HAS_PANDAS = False\nexcept ImportError:\n  HAS_PANDAS = False\n\n# pylint rules which are disabled by default for test files.\n# pylint: disable=invalid-name,protected-access,missing-docstring\n\n# Names of variables created by model.\nBIAS_NAME = 'baseline/bias'\n\n\ndef assert_close(expected, actual, rtol=1e-04, name='assert_close'):\n  with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope:\n    expected = ops.convert_to_tensor(expected, name='expected')\n    actual = ops.convert_to_tensor(actual, name='actual')\n    rdiff = tf.math.abs(expected - actual, 'diff') / tf.math.abs(expected)\n    rtol = ops.convert_to_tensor(rtol, name='rtol')\n    return tf.compat.v1.debugging.assert_less(\n        rdiff,\n        rtol,\n        data=('Condition expected =~ actual did not hold element-wise:'\n              'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff,\n              'rtol = ', rtol,),\n        name=scope)\n\n\ndef save_variables_to_ckpt(model_dir):\n  init_all_op = [tf.compat.v1.initializers.global_variables()]\n  with tf.compat.v1.Session() as sess:\n    sess.run(init_all_op)\n    tf.compat.v1.train.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))\n\n\ndef queue_parsed_features(feature_map):\n  tensors_to_enqueue = []\n  keys = []\n  for key, tensor in six.iteritems(feature_map):\n    keys.append(key)\n    tensors_to_enqueue.append(tensor)\n  queue_dtypes = [x.dtype for x in tensors_to_enqueue]\n  input_queue = tf.queue.FIFOQueue(capacity=100, dtypes=queue_dtypes)\n  tf.compat.v1.train.queue_runner.add_queue_runner(\n      tf.compat.v1.train.queue_runner.QueueRunner(\n          input_queue, [input_queue.enqueue(tensors_to_enqueue)]))\n  dequeued_tensors = input_queue.dequeue()\n  return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}\n\n\ndef sorted_key_dict(unsorted_dict):\n  return {k: unsorted_dict[k] for k in sorted(unsorted_dict)}\n\n\ndef sigmoid(x):\n  
return 1 / (1 + np.exp(-1.0 * x))\n\n\ndef _baseline_regressor_fn(*args, **kwargs):\n  return baseline.BaselineRegressorV2(*args, **kwargs)\n\n\ndef _baseline_classifier_fn(*args, **kwargs):\n  return baseline.BaselineClassifierV2(*args, **kwargs)\n\n\ndef mock_optimizer_v2(testcase, expected_loss=None):\n  \"\"\"Creates a mock optimizer to test the train method.\n\n  Args:\n    testcase: A TestCase instance.\n    expected_loss: If given, will assert the loss value.\n\n  Returns:\n    A mock Optimizer.\n  \"\"\"\n  expected_var_names = ['%s:0' % BIAS_NAME]\n\n  class _Optimizer(tf_keras.optimizers.legacy.Optimizer):\n\n    def get_updates(self, loss, params):\n      trainable_vars = params\n      testcase.assertItemsEqual(expected_var_names,\n                                [var.name for var in trainable_vars])\n\n      # Verify loss. We can't check the value directly, so we add an assert op.\n      testcase.assertEquals(0, loss.shape.ndims)\n      if expected_loss is None:\n        if self.iterations is not None:\n          return [self.iterations.assign_add(1).op]\n        return [tf.no_op()]\n      assert_loss = assert_close(\n          tf.cast(expected_loss, name='expected', dtype=tf.dtypes.float32),\n          loss,\n          name='assert_loss')\n      with tf.control_dependencies((assert_loss,)):\n        if self.iterations is not None:\n          return [self.iterations.assign_add(1).op]\n        return [tf.no_op()]\n\n    def get_config(self):\n      config = super(_Optimizer, self).get_config()\n      return config\n\n  optimizer = _Optimizer(name='my_optimizer')\n\n  return optimizer\n\n\n# Tests for Baseline Regressor.\n\n# TODO(b/36813849): Add tests with dynamic shape inputs using placeholders.\n\n\nclass BaselineRegressorEvaluationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      
shutil.rmtree(self._model_dir)\n\n  def test_evaluation_for_simple_data(self):\n    with tf.Graph().as_default():\n      tf.Variable([13.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)\n    eval_metrics = baseline_regressor.evaluate(\n        input_fn=lambda: ({\n            'age': ((1,),)\n        }, ((10.,),)), steps=1)\n\n    # Logit is bias = 13, while label is 10. Loss is 3**2 = 9.\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 9.,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def test_evaluation_batch(self):\n    \"\"\"Tests evaluation for batch_size==2.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([13.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)\n    eval_metrics = baseline_regressor.evaluate(\n        input_fn=lambda: ({\n            'age': ((1,), (1,))\n        }, ((10.,), (10.,))), steps=1)\n\n    # Logit is bias = 13, while label is 10.\n    # Loss per example is 3**2 = 9.\n    # Training loss is the sum over batch size = (9 + 9) / 2 = 9\n    # Average loss is the average over batch = 9\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 9.,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def 
test_evaluation_weights(self):\n    \"\"\"Tests evaluation with weights.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([13.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    def _input_fn():\n      features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))}\n      labels = ((10.,), (10.,))\n      return features, labels\n\n    baseline_regressor = _baseline_regressor_fn(\n        weight_column='weights', model_dir=self._model_dir)\n    eval_metrics = baseline_regressor.evaluate(input_fn=_input_fn, steps=1)\n\n    # Logit is bias = 13, while label is 10.\n    # Loss per example is 3**2 = 9.\n    # Training loss is the weighted sum over batch size = (9 + 2*9) / 2 = 13.5\n    # average loss is the weighted average = 9 + 2*9 / (1 + 2) = 9\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 13.5,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def test_evaluation_for_multi_dimensions(self):\n    label_dim = 2\n    with tf.Graph().as_default():\n      tf.Variable([46.0, 58.0], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_regressor = _baseline_regressor_fn(\n        label_dimension=label_dim, model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': np.array([[2., 4., 5.]]),\n        },\n        y=np.array([[46., 58.]]),\n        batch_size=1,\n        num_epochs=None,\n        shuffle=False)\n    eval_metrics = baseline_regressor.evaluate(input_fn=input_fn, steps=1)\n\n    self.assertItemsEqual(\n        (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,\n      
   metric_keys.MetricKeys.PREDICTION_MEAN,\n         metric_keys.MetricKeys.LABEL_MEAN, tf.compat.v1.GraphKeys.GLOBAL_STEP),\n        eval_metrics.keys())\n\n    # Logit is bias which is [46, 58]\n    self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])\n\n\nclass BaselineRegressorPredictTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_1d(self):\n    \"\"\"Tests predict when all variables are one-dimensional.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([.2], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': np.array([[2.]])},\n        y=None,\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    predictions = baseline_regressor.predict(input_fn=predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # score = bias = .2
\n    self.assertAllClose([[.2]], predicted_scores)\n\n  def testMultiDim(self):\n    \"\"\"Tests predict when all variables are multi-dimensional.\"\"\"\n    batch_size = 2\n    label_dimension = 3\n    with tf.Graph().as_default():\n      tf.Variable(  # shape=[label_dimension]\n          [.2, .4, .6], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_regressor = _baseline_regressor_fn(\n        label_dimension=label_dimension, model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        # x shape=[batch_size, x_dim]\n        x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])},\n        y=None,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n    predictions = baseline_regressor.predict(input_fn=predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # score = bias, shape=[batch_size, label_dimension]\n    self.assertAllClose([[0.2, 0.4, 0.6], [0.2, 0.4, 0.6]], predicted_scores)\n\n\nclass BaselineRegressorIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, label_dimension, prediction_length):\n    feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    est = _baseline_regressor_fn(\n        label_dimension=label_dimension, model_dir=self._model_dir)\n\n    # TRAIN\n    # learn y = x\n    est.train(train_input_fn, steps=200)\n\n    # EVALUATE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(200, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn(metric_keys.MetricKeys.LOSS, 
six.iterkeys(scores))\n\n    # PREDICT\n    predictions = np.array(\n        [x['predictions'] for x in est.predict(predict_input_fn)])\n    self.assertAllEqual((prediction_length, label_dimension), predictions.shape)\n\n    # EXPORT\n    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    label_dimension = 2\n    input_dimension = label_dimension\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=None,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        label_dimension=label_dimension,\n        prediction_length=prediction_length)\n\n  def test_pandas_input_fn(self):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n\n    # Pandas DataFrame natually supports 1 dim data only.\n    label_dimension = 1\n    input_dimension = 
label_dimension\n    batch_size = 10\n    data = np.array([1., 2., 3., 4.], dtype=np.float32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(data)\n    prediction_length = 4\n\n    train_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        label_dimension=label_dimension,\n        prediction_length=prediction_length)\n\n  def test_input_fn_from_parse_example(self):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    label_dimension = 2\n    input_dimension = label_dimension\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n\n    serialized_examples = []\n    for datum in data:\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n                  'y':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(\n                              value=datum[:label_dimension])),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([input_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([label_dimension], tf.dtypes.float32),\n    }\n\n    def _train_input_fn():\n      feature_map = 
tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=input_dimension,\n        label_dimension=label_dimension,\n        prediction_length=prediction_length)\n\n\nclass BaselineRegressorTrainingTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _assert_checkpoint(self,\n                         label_dimension,\n                         expected_global_step,\n                         expected_bias=None):\n    shapes = {\n        name: shape\n        for (name, shape) in tf.train.list_variables(self._model_dir)\n    }\n\n    self.assertEqual([], shapes[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertEqual(\n        expected_global_step,\n        tf.train.load_variable(self._model_dir,\n                               tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n    self.assertEqual([label_dimension], shapes[BIAS_NAME])\n    if expected_bias 
is not None:\n      self.assertEqual(expected_bias,\n                       tf.train.load_variable(self._model_dir, BIAS_NAME))\n\n  def testFromScratchWithDefaultOptimizer(self):\n    # Create BaselineRegressor.\n    label = 5.\n    age = 17\n    baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)\n\n    # Train for a few steps, and validate final checkpoint.\n    num_steps = 10\n    baseline_regressor.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self._assert_checkpoint(label_dimension=1, expected_global_step=num_steps)\n\n  def testTrainWithOneDimLabel(self):\n    label_dimension = 1\n    batch_size = 20\n    est = _baseline_regressor_fn(\n        label_dimension=label_dimension, model_dir=self._model_dir)\n    data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32)\n    self.assertEqual((batch_size,), data_rank_1.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'age': data_rank_1},\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(label_dimension=1, expected_global_step=200)\n\n  def testTrainWithOneDimWeight(self):\n    label_dimension = 1\n    batch_size = 20\n    est = _baseline_regressor_fn(\n        label_dimension=label_dimension,\n        weight_column='w',\n        model_dir=self._model_dir)\n\n    data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32)\n    self.assertEqual((batch_size,), data_rank_1.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': data_rank_1,\n            'w': data_rank_1\n        },\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(label_dimension=1, expected_global_step=200)\n\n  def testFromScratch(self):\n    # Create 
BaselineRegressor.\n    label = 5.\n    age = 17\n    # loss = (logits - label)^2 = (0 - 5.)^2 = 25.\n    mock_optimizer = mock_optimizer_v2(self, expected_loss=25.)\n    baseline_regressor = _baseline_regressor_fn(\n        model_dir=self._model_dir, optimizer=mock_optimizer)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    baseline_regressor.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self.assertEqual(\n        num_steps,\n        baseline_regressor.get_variable_value(mock_optimizer.iterations.name))\n    self._assert_checkpoint(\n        label_dimension=1, expected_global_step=num_steps, expected_bias=[0.])\n\n  def testFromCheckpoint(self):\n    # Create initial checkpoint.\n    bias = 7.0\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable([bias], name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # logits = bias = 6.\n    # loss = (logits - label)^2 = (7 - 5)^2 = 4\n    mock_optimizer = mock_optimizer_v2(self, expected_loss=4.)\n    baseline_regressor = _baseline_regressor_fn(\n        model_dir=self._model_dir, optimizer=mock_optimizer)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    baseline_regressor.train(\n        input_fn=lambda: ({\n            'age': ((17,),)\n        }, ((5.,),)), steps=num_steps)\n    self.assertEqual(\n        initial_global_step + num_steps,\n        baseline_regressor.get_variable_value(mock_optimizer.iterations.name))\n    self._assert_checkpoint(\n        label_dimension=1,\n        expected_global_step=initial_global_step + num_steps,\n        expected_bias=[bias])\n\n  def testFromCheckpointMultiBatch(self):\n    # Create initial checkpoint.\n    bias = 5.0\n    
initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable([bias], name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # logits = bias\n    # logits[0] = 5.\n    # logits[1] = 5.\n    # loss = (sum(logits - label)^2 = (5 - 5)^2 + (5 - 3)^2) / 2 (batch size)\n    # loss = 2\n    mock_optimizer = mock_optimizer_v2(self, expected_loss=2.)\n    baseline_regressor = _baseline_regressor_fn(\n        model_dir=self._model_dir, optimizer=mock_optimizer)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    baseline_regressor.train(\n        input_fn=lambda: ({\n            'age': ((17,), (15,))\n        }, ((5.,), (3.,))),\n        steps=num_steps)\n    self.assertEqual(\n        initial_global_step + num_steps,\n        baseline_regressor.get_variable_value(mock_optimizer.iterations.name))\n    self._assert_checkpoint(\n        label_dimension=1,\n        expected_global_step=initial_global_step + num_steps,\n        expected_bias=bias)\n\n\n# Tests for Baseline Classifier.\n\n\nclass BaselineClassifierTrainingTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _assert_checkpoint(self,\n                         n_classes,\n                         expected_global_step,\n                         expected_bias=None):\n    logits_dimension = n_classes if n_classes > 2 else 1\n\n    shapes = {\n        name: shape\n        for (name, shape) in tf.train.list_variables(self._model_dir)\n    }\n\n    self.assertEqual([], shapes[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertEqual(\n        expected_global_step,\n        tf.train.load_variable(self._model_dir,\n                               
tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n    self.assertEqual([logits_dimension], shapes[BIAS_NAME])\n    if expected_bias is not None:\n      self.assertAllEqual(expected_bias,\n                          tf.train.load_variable(self._model_dir, BIAS_NAME))\n\n  def _testFromScratchWithDefaultOptimizer(self, n_classes):\n    label = 0\n    age = 17\n    est = baseline.BaselineClassifierV2(\n        n_classes=n_classes, model_dir=self._model_dir)\n\n    # Train for a few steps, and validate final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self._assert_checkpoint(n_classes, num_steps)\n\n  def testBinaryClassesFromScratchWithDefaultOptimizer(self):\n    self._testFromScratchWithDefaultOptimizer(n_classes=2)\n\n  def testMultiClassesFromScratchWithDefaultOptimizer(self):\n    self._testFromScratchWithDefaultOptimizer(n_classes=4)\n\n  def _testTrainWithTwoDimsLabel(self, n_classes):\n    batch_size = 20\n\n    est = baseline.BaselineClassifierV2(\n        n_classes=n_classes, model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    data_rank_2 = np.array([[0], [1]])\n    self.assertEqual((2,), data_rank_1.shape)\n    self.assertEqual((2, 1), data_rank_2.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'age': data_rank_1},\n        y=data_rank_2,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithTwoDimsLabel(self):\n    self._testTrainWithTwoDimsLabel(n_classes=2)\n\n  def testMultiClassesTrainWithTwoDimsLabel(self):\n    self._testTrainWithTwoDimsLabel(n_classes=4)\n\n  def _testTrainWithOneDimLabel(self, n_classes):\n    batch_size = 20\n\n    est = baseline.BaselineClassifierV2(\n        n_classes=n_classes, model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    
self.assertEqual((2,), data_rank_1.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'age': data_rank_1},\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithOneDimLabel(self):\n    self._testTrainWithOneDimLabel(n_classes=2)\n\n  def testMultiClassesTrainWithOneDimLabel(self):\n    self._testTrainWithOneDimLabel(n_classes=4)\n\n  def _testTrainWithTwoDimsWeight(self, n_classes):\n    batch_size = 20\n\n    est = baseline.BaselineClassifierV2(\n        weight_column='w', n_classes=n_classes, model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    data_rank_2 = np.array([[0], [1]])\n    self.assertEqual((2,), data_rank_1.shape)\n    self.assertEqual((2, 1), data_rank_2.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': data_rank_1,\n            'w': data_rank_2\n        },\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithTwoDimsWeight(self):\n    self._testTrainWithTwoDimsWeight(n_classes=2)\n\n  def testMultiClassesTrainWithTwoDimsWeight(self):\n    self._testTrainWithTwoDimsWeight(n_classes=4)\n\n  def _testTrainWithOneDimWeight(self, n_classes):\n    batch_size = 20\n\n    est = baseline.BaselineClassifierV2(\n        weight_column='w', n_classes=n_classes, model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    self.assertEqual((2,), data_rank_1.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': data_rank_1,\n            'w': data_rank_1\n        },\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    
self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithOneDimWeight(self):\n    self._testTrainWithOneDimWeight(n_classes=2)\n\n  def testMultiClassesTrainWithOneDimWeight(self):\n    self._testTrainWithOneDimWeight(n_classes=4)\n\n  def _testFromScratch(self, n_classes):\n    label = 1\n    age = 17\n    # For binary classifier:\n    #   loss = sigmoid_cross_entropy(logits, label) where logits=0 (weights are\n    #   all zero initially) and label = 1 so,\n    #      loss = 1 * -log ( sigmoid(logits) ) = 0.69315\n    # For multi class classifier:\n    #   loss = cross_entropy(logits, label) where logits are all 0s (weights are\n    #   all zero initially) and label = 1 so,\n    #      loss = 1 * -log ( 1.0 / n_classes )\n    # For this particular test case, as logits are same, the formula\n    # 1 * -log ( 1.0 / n_classes ) covers both binary and multi class cases.\n    mock_optimizer = mock_optimizer_v2(\n        self, expected_loss=-1 * math.log(1.0 / n_classes))\n\n    est = baseline.BaselineClassifierV2(\n        n_classes=n_classes,\n        optimizer=mock_optimizer,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self.assertEqual(num_steps,\n                     est.get_variable_value(mock_optimizer.iterations.name))\n    self._assert_checkpoint(\n        n_classes,\n        expected_global_step=num_steps,\n        expected_bias=[0.] 
if n_classes == 2 else [.0] * n_classes)\n\n  def testBinaryClassesFromScratch(self):\n    self._testFromScratch(n_classes=2)\n\n  def testMultiClassesFromScratch(self):\n    self._testFromScratch(n_classes=4)\n\n  def _testFromCheckpoint(self, n_classes):\n    # Create initial checkpoint.\n    label = 1\n    age = 17\n    bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # For binary classifier:\n    #   logits = bias = -1.\n    #   loss = sigmoid_cross_entropy(logits, label)\n    #   so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133\n    # For multi class classifier:\n    #   loss = cross_entropy(logits, label)\n    #   where logits = bias and label = 1\n    #   so, loss = 1 * -log ( softmax(logits)[1] )\n    if n_classes == 2:\n      expected_loss = 1.3133\n    else:\n      logits = bias\n      logits_exp = np.exp(logits)\n      softmax = logits_exp / logits_exp.sum()\n      expected_loss = -1 * math.log(softmax[label])\n\n    mock_optimizer = mock_optimizer_v2(self, expected_loss=expected_loss)\n\n    est = baseline.BaselineClassifierV2(\n        n_classes=n_classes,\n        optimizer=mock_optimizer,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self.assertEqual(initial_global_step + num_steps,\n                     est.get_variable_value(mock_optimizer.iterations.name))\n    self._assert_checkpoint(\n        n_classes,\n        expected_global_step=initial_global_step + num_steps,\n        expected_bias=bias)\n\n  def testBinaryClassesFromCheckpoint(self):\n  
  self._testFromCheckpoint(n_classes=2)\n\n  def testMultiClassesFromCheckpoint(self):\n    self._testFromCheckpoint(n_classes=4)\n\n  def _testFromCheckpointFloatLabels(self, n_classes):\n    \"\"\"Tests float labels for binary classification.\"\"\"\n    # Create initial checkpoint.\n    if n_classes > 2:\n      return\n    label = 0.8\n    age = 17\n    bias = [-1.0]\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # logits = bias = -1.\n    # loss = sigmoid_cross_entropy(logits, label)\n    # => loss = -0.8 * log(sigmoid(-1)) -0.2 * log(sigmoid(+1)) = 1.1132617\n    mock_optimizer = mock_optimizer_v2(self, expected_loss=1.1132617)\n\n    est = baseline.BaselineClassifierV2(\n        n_classes=n_classes,\n        optimizer=mock_optimizer,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self.assertEqual(initial_global_step + num_steps,\n                     est.get_variable_value(mock_optimizer.iterations.name))\n\n  def testBinaryClassesFromCheckpointFloatLabels(self):\n    self._testFromCheckpointFloatLabels(n_classes=2)\n\n  def testMultiClassesFromCheckpointFloatLabels(self):\n    self._testFromCheckpointFloatLabels(n_classes=4)\n\n  def _testFromCheckpointMultiBatch(self, n_classes):\n    # Create initial checkpoint.\n    label = [1, 0]\n    age = [17, 18.5]\n    batch_size = 2\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). 
In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # For binary classifier:\n    #   logits = bias\n    #   logits[0] = -1.\n    #   logits[1] = -1.\n    #   loss = sigmoid_cross_entropy(logits, label)\n    #   so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133\n    #       loss[1] = (1 - 0) * -log ( 1- sigmoid(-1) ) = 0.3132\n    # For multi class classifier:\n    #   loss = cross_entropy(logits, label)\n    #   where logits = bias and label = [1, 0]\n    #   so, loss = 1 * -log ( softmax(logits)[label] )\n    if n_classes == 2:\n      expected_loss = (1.3133 + 0.3132) / 2\n    else:\n      # Expand logits since batch_size=2\n      logits = bias * np.ones(shape=(2, 1))\n      logits_exp = np.exp(logits)\n      softmax_row_0 = logits_exp[0] / logits_exp[0].sum()\n      softmax_row_1 = logits_exp[1] / logits_exp[1].sum()\n      expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])\n      expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])\n      expected_loss = (expected_loss_0 + expected_loss_1) / 2\n\n    mock_optimizer = mock_optimizer_v2(self, expected_loss=expected_loss)\n\n    est = baseline.BaselineClassifierV2(\n        n_classes=n_classes,\n        optimizer=mock_optimizer,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(input_fn=lambda: ({'age': (age)}, (label)), steps=num_steps)\n    self.assertEqual(initial_global_step + num_steps,\n                     est.get_variable_value(mock_optimizer.iterations.name))\n    self._assert_checkpoint(\n        n_classes,\n        
expected_global_step=initial_global_step + num_steps,\n        expected_bias=bias)\n\n  def testBinaryClassesFromCheckpointMultiBatch(self):\n    self._testFromCheckpointMultiBatch(n_classes=2)\n\n  def testMultiClassesFromCheckpointMultiBatch(self):\n    self._testFromCheckpointMultiBatch(n_classes=4)\n\n\nclass BaselineClassifierEvaluationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _test_evaluation_for_simple_data(self, n_classes):\n    label = 1\n    age = 1.\n\n    bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes\n\n    with tf.Graph().as_default():\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = _baseline_classifier_fn(\n        n_classes=n_classes, model_dir=self._model_dir)\n    eval_metrics = est.evaluate(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=1)\n\n    if n_classes == 2:\n      # Binary classes: loss = -log(sigmoid(-1)) / batch size = 1.3133\n      # Prediction = sigmoid(-1) = 0.2689\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: 1.3133,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: 1.3133,\n          metric_keys.MetricKeys.ACCURACY: 0.,\n          metric_keys.MetricKeys.PRECISION: 0.,\n          metric_keys.MetricKeys.RECALL: 0.,\n          metric_keys.MetricKeys.PREDICTION_MEAN: 0.2689,\n          metric_keys.MetricKeys.LABEL_MEAN: 1.,\n          metric_keys.MetricKeys.ACCURACY_BASELINE: 1,\n          metric_keys.MetricKeys.AUC: 0.,\n          metric_keys.MetricKeys.AUC_PR: 1.,\n      }\n    else:\n      # Multi classes: loss = 1 * -log ( softmax(logits)[label] )\n      logits = bias\n      logits_exp = np.exp(logits)\n      softmax = 
logits_exp / logits_exp.sum()\n      expected_loss = -1 * math.log(softmax[label])\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: expected_loss,\n          metric_keys.MetricKeys.ACCURACY: 0.,\n      }\n\n    self.assertAllClose(\n        sorted_key_dict(expected_metrics),\n        sorted_key_dict(eval_metrics),\n        rtol=1e-3)\n\n  def test_binary_classes_evaluation_for_simple_data(self):\n    self._test_evaluation_for_simple_data(n_classes=2)\n\n  def test_multi_classes_evaluation_for_simple_data(self):\n    self._test_evaluation_for_simple_data(n_classes=4)\n\n  def _test_evaluation_batch(self, n_classes):\n    \"\"\"Tests evaluation for batch_size==2.\"\"\"\n    label = [1, 0]\n    age = [17., 18.]\n    bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = _baseline_classifier_fn(\n        n_classes=n_classes, model_dir=self._model_dir)\n    eval_metrics = est.evaluate(\n        input_fn=lambda: ({\n            'age': (age)\n        }, (label)), steps=1)\n\n    if n_classes == 2:\n      # Logits are (-1., -1.) 
labels are (1, 0).\n      # Loss is\n      #   loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133\n      #   loss for row 2: (1 - 0) * -log(1 - sigmoid(-1)) = 0.3132\n      # Prediction = sigmoid(-1) = 0.2689\n      expected_loss = (1.3133 + 0.3132) / 2  # batch size\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: expected_loss,\n          metric_keys.MetricKeys.ACCURACY: 0.5,\n          metric_keys.MetricKeys.PRECISION: 0.,\n          metric_keys.MetricKeys.RECALL: 0.,\n          metric_keys.MetricKeys.PREDICTION_MEAN: 0.2689,\n          metric_keys.MetricKeys.LABEL_MEAN: 0.5,\n          metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,\n          metric_keys.MetricKeys.AUC: 0.5,\n          metric_keys.MetricKeys.AUC_PR: 0.5,\n      }\n    else:\n      # Expand logits since batch_size=2\n      logits = bias * np.ones(shape=(2, 1))\n      logits_exp = np.exp(logits)\n      softmax_row_0 = logits_exp[0] / logits_exp[0].sum()\n      softmax_row_1 = logits_exp[1] / logits_exp[1].sum()\n      expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])\n      expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])\n      expected_loss = (expected_loss_0 + expected_loss_1) / 2  # batch size\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: expected_loss,\n          metric_keys.MetricKeys.ACCURACY: 0.5,\n      }\n\n    self.assertAllClose(\n        sorted_key_dict(expected_metrics),\n        sorted_key_dict(eval_metrics),\n        rtol=1e-3)\n\n  def test_binary_classes_evaluation_batch(self):\n    self._test_evaluation_batch(n_classes=2)\n\n  def test_multi_classes_evaluation_batch(self):\n    self._test_evaluation_batch(n_classes=4)\n\n  def _test_evaluation_weights(self, n_classes):\n    \"\"\"Tests evaluation with 
weights.\"\"\"\n\n    label = [1, 0]\n    age = [17., 18.]\n    weights = [1., 2.]\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = _baseline_classifier_fn(\n        n_classes=n_classes, weight_column='w', model_dir=self._model_dir)\n    eval_metrics = est.evaluate(\n        input_fn=lambda: ({\n            'age': (age),\n            'w': (weights)\n        }, (label)), steps=1)\n\n    if n_classes == 2:\n      # Logits are (-1., -1.) labels are (1, 0).\n      # Loss is\n      #   loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133\n      #   loss for row 2: (1 - 0) * -log(1 - sigmoid(-1)) = 0.3132\n      #   weights = [1., 2.]\n      expected_loss = (1.3133 * 1. + 0.3132 * 2.) / 2  # batch size\n      loss_mean = (1.3133 * 1. + 0.3132 * 2.) / (1.0 + 2.0)\n      label_mean = np.average(label, weights=weights)\n      logits = [-1, -1]\n      logistics = sigmoid(np.array(logits))\n      predictions_mean = np.average(logistics, weights=weights)\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: loss_mean,\n          metric_keys.MetricKeys.ACCURACY: 2. / (1. 
+ 2.),\n          metric_keys.MetricKeys.PRECISION: 0.,\n          metric_keys.MetricKeys.RECALL: 0.,\n          metric_keys.MetricKeys.PREDICTION_MEAN: predictions_mean,\n          metric_keys.MetricKeys.LABEL_MEAN: label_mean,\n          metric_keys.MetricKeys.ACCURACY_BASELINE:\n              (max(label_mean, 1 - label_mean)),\n          metric_keys.MetricKeys.AUC: 0.5,\n          metric_keys.MetricKeys.AUC_PR: 0.33333,\n      }\n    else:\n      # Multi classes: unweighted_loss = 1 * -log ( soft_max(logits)[label] )\n      # Expand logits since batch_size=2\n      logits = bias * np.ones(shape=(2, 1))\n      logits_exp = np.exp(logits)\n      softmax_row_0 = logits_exp[0] / logits_exp[0].sum()\n      softmax_row_1 = logits_exp[1] / logits_exp[1].sum()\n      expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])\n      expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])\n      loss_mean = np.average([expected_loss_0, expected_loss_1],\n                             weights=weights)\n      expected_loss = (loss_mean * np.sum(weights)) / 2  # batch size\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: loss_mean,\n          metric_keys.MetricKeys.ACCURACY: 2. / (1. 
+ 2.),\n      }\n\n    self.assertAllClose(\n        sorted_key_dict(expected_metrics),\n        sorted_key_dict(eval_metrics),\n        rtol=1e-3)\n\n  def test_binary_classes_evaluation_weights(self):\n    self._test_evaluation_weights(n_classes=2)\n\n  def test_multi_classes_evaluation_weights(self):\n    self._test_evaluation_weights(n_classes=4)\n\n\nclass BaselineClassifierPredictTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _testPredictions(self, n_classes, label_vocabulary, label_output_fn):\n    \"\"\"Tests predict when all variables are one-dimensional.\"\"\"\n    age = 1.\n\n    bias = [10.0] if n_classes == 2 else [10.0] * n_classes\n\n    with tf.Graph().as_default():\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = _baseline_classifier_fn(\n        label_vocabulary=label_vocabulary,\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'age': np.array([[age]])},\n        y=None,\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    predictions = list(est.predict(input_fn=predict_input_fn))\n\n    if n_classes == 2:\n      scalar_logits = bias[0]\n      two_classes_logits = [0, scalar_logits]\n      two_classes_logits_exp = np.exp(two_classes_logits)\n      softmax = two_classes_logits_exp / two_classes_logits_exp.sum()\n\n      expected_predictions = {\n          'class_ids': [1],\n          'all_class_ids': [0, 1],\n          'classes': [label_output_fn(1)],\n          'all_classes': [label_output_fn(0),\n                          label_output_fn(1)],\n          'logistic': [sigmoid(np.array(scalar_logits))],\n          'logits': [scalar_logits],\n          'probabilities': softmax,\n      }\n    
else:\n      onedim_logits = np.array(bias)\n      class_ids = onedim_logits.argmax()\n      all_class_ids = list(range(len(onedim_logits)))\n      logits_exp = np.exp(onedim_logits)\n      softmax = logits_exp / logits_exp.sum()\n      expected_predictions = {\n          'class_ids': [class_ids],\n          'all_class_ids': all_class_ids,\n          'classes': [label_output_fn(class_ids)],\n          'all_classes': [label_output_fn(i) for i in all_class_ids],\n          'logits': onedim_logits,\n          'probabilities': softmax,\n      }\n\n    self.assertEqual(1, len(predictions))\n    # assertAllClose cannot handle byte type.\n    self.assertEqual(expected_predictions['classes'], predictions[0]['classes'])\n    expected_predictions.pop('classes')\n    predictions[0].pop('classes')\n    self.assertAllEqual(expected_predictions['all_classes'],\n                        predictions[0]['all_classes'])\n    expected_predictions.pop('all_classes')\n    predictions[0].pop('all_classes')\n    self.assertAllClose(\n        sorted_key_dict(expected_predictions), sorted_key_dict(predictions[0]))\n\n  def testBinaryClassesWithoutLabelVocabulary(self):\n    n_classes = 2\n    self._testPredictions(\n        n_classes,\n        label_vocabulary=None,\n        label_output_fn=lambda x: ('%s' % x).encode())\n\n  def testBinaryClassesWithLabelVocabulary(self):\n    n_classes = 2\n    self._testPredictions(\n        n_classes,\n        label_vocabulary=['class_vocab_{}'.format(i) for i in range(n_classes)],\n        label_output_fn=lambda x: ('class_vocab_%s' % x).encode())\n\n  def testMultiClassesWithoutLabelVocabulary(self):\n    n_classes = 4\n    self._testPredictions(\n        n_classes,\n        label_vocabulary=None,\n        label_output_fn=lambda x: ('%s' % x).encode())\n\n  def testMultiClassesWithLabelVocabulary(self):\n    n_classes = 4\n    self._testPredictions(\n        n_classes,\n        label_vocabulary=['class_vocab_{}'.format(i) for i in range(n_classes)],\n 
       label_output_fn=lambda x: ('class_vocab_%s' % x).encode())\n\n\nclass BaselineClassifierIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self, n_classes, train_input_fn, eval_input_fn,\n                          predict_input_fn, input_dimension, prediction_length):\n    feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    est = _baseline_classifier_fn(\n        n_classes=n_classes, model_dir=self._model_dir)\n\n    # TRAIN\n    # learn y = x\n    est.train(train_input_fn, steps=200)\n\n    # EVALUTE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(200, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores))\n\n    # PREDICT\n    predictions = np.array(\n        [x['classes'] for x in est.predict(predict_input_fn)])\n    self.assertAllEqual((prediction_length, 1), predictions.shape)\n\n    # EXPORT\n    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def _test_numpy_input_fn(self, n_classes):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    input_dimension = 4\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, input_dimension)\n    target = np.array([1] * batch_size)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=target,\n        batch_size=batch_size,\n        num_epochs=None,\n     
   shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=target,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=None,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n\n    self._test_complete_flow(\n        n_classes=n_classes,\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        prediction_length=prediction_length)\n\n  def test_binary_classes_numpy_input_fn(self):\n    self._test_numpy_input_fn(n_classes=2)\n\n  def test_multi_classes_numpy_input_fn(self):\n    self._test_numpy_input_fn(n_classes=4)\n\n  def _test_pandas_input_fn(self, n_classes):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n\n    # Pandas DataFrame natually supports 1 dim data only.\n    input_dimension = 1\n    batch_size = 10\n    data = np.array([1., 2., 3., 4.], dtype=np.float32)\n    target = np.array([1, 0, 1, 0], dtype=np.int32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(target)\n    prediction_length = 4\n\n    train_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        n_classes=n_classes,\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        prediction_length=prediction_length)\n\n  def test_binary_classes_pandas_input_fn(self):\n    self._test_pandas_input_fn(n_classes=2)\n\n  def 
test_multi_classes_pandas_input_fn(self):\n    self._test_pandas_input_fn(n_classes=4)\n\n  def _test_input_fn_from_parse_example(self, n_classes):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    input_dimension = 2\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, input_dimension)\n    target = np.array([1] * batch_size, dtype=np.int64)\n\n    serialized_examples = []\n    for x, y in zip(data, target):\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=x)),\n                  'y':\n                      feature_pb2.Feature(\n                          int64_list=feature_pb2.Int64List(value=[y])),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([input_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([1], tf.dtypes.int64),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = 
queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        n_classes=n_classes,\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=input_dimension,\n        prediction_length=prediction_length)\n\n  def test_binary_classes_input_fn_from_parse_example(self):\n    self._test_input_fn_from_parse_example(n_classes=2)\n\n  def test_multi_classes_input_fn_from_parse_example(self):\n    self._test_input_fn_from_parse_example(n_classes=4)\n\n\n# Tests for Baseline logit_fn.\n\n\nclass BaselineLogitFnTest(tf.test.TestCase):\n\n  def test_basic_logit_correctness(self):\n    \"\"\"baseline_logit_fn simply returns the bias variable.\"\"\"\n    with tf.Graph().as_default():\n      bias_var, logits = baseline._baseline_model_fn_builder_v2(\n          features={'age': [[23.], [31.]]}, num_outputs=2)\n      with tf.compat.v1.Session() as sess:\n        sess.run([tf.compat.v1.initializers.global_variables()])\n        self.assertAllClose([[0., 0.], [0., 0.]], logits.eval())\n        sess.run(bias_var[0].assign([10., 5.]))\n        self.assertAllClose([[10., 5.], [10., 5.]], logits.eval())\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/canned_estimator_ds_integration_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests canned estimators with distribution strategy.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport inspect\nimport tempfile\n\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator import run_config\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import dnn\nfrom tensorflow_estimator.python.estimator.canned import dnn_linear_combined\nfrom tensorflow_estimator.python.estimator.canned import linear\nfrom tensorflow_estimator.python.estimator.extenders import add_metrics\n\n\nclass CannedEstimatorDistributionStrategyTest(tf.test.TestCase,\n                                              parameterized.TestCase):\n\n  def setUp(self):\n    super(CannedEstimatorDistributionStrategyTest, self).setUp()\n    np.random.seed(1337)\n    tf.compat.v1.random.set_random_seed(1337)\n\n    self._model_dir = tempfile.mkdtemp()\n\n  def dataset_input_fn(self, x, y, batch_size, shuffle):\n\n    def input_fn():\n      dataset = tf.compat.v1.data.Dataset.from_tensor_slices((x, y))\n      if shuffle:\n        dataset = dataset.shuffle(batch_size)\n      
dataset = dataset.repeat(10).batch(batch_size)\n      return dataset\n\n    return input_fn\n\n  @tf.compat.v2.__internal__.distribute.combinations.generate(\n      tf.compat.v2.__internal__.test.combinations.combine(\n          mode=['graph', 'eager'],\n          distribution=[\n              tf.compat.v2.__internal__.distribute.combinations.one_device_strategy,\n              tf.compat.v2.__internal__.distribute.combinations.mirrored_strategy_with_gpu_and_cpu,\n              tf.compat.v2.__internal__.distribute.combinations.mirrored_strategy_with_two_gpus,\n          ],\n          estimator_cls=[\n              dnn_linear_combined.DNNLinearCombinedRegressorV2,\n              dnn.DNNRegressorV2,\n              linear.LinearRegressorV2,\n          ]))\n  def test_canned_estimator(self, distribution, estimator_cls):\n    label_dimension = 2\n    batch_size = 10\n    # Adding one extra row (+ label_dimension) to test the last partial batch\n    # use case.\n    data = np.linspace(\n        0.,\n        2.,\n        batch_size * label_dimension + label_dimension,\n        dtype=np.float32)\n    data = data.reshape(batch_size + 1, label_dimension)\n    fc = tf.feature_column.numeric_column('x', shape=(2,))\n\n    # Set kwargs based on the current canned estimator class.\n    estimator_kw_args = {\n        'model_dir': self._model_dir,\n        'label_dimension': 2,\n    }\n\n    cls_args = inspect.getargspec(estimator_cls.__init__).args\n    if 'hidden_units' in cls_args:\n      estimator_kw_args['hidden_units'] = [2, 2]\n    elif 'dnn_hidden_units' in cls_args:\n      estimator_kw_args['dnn_hidden_units'] = [2, 2]\n\n    if 'optimizer' in cls_args:\n      estimator_kw_args['optimizer'] = 'SGD'\n    else:\n      estimator_kw_args['linear_optimizer'] = 'SGD'\n      estimator_kw_args['dnn_optimizer'] = 'SGD'\n\n    if 'feature_columns' in cls_args:\n      estimator_kw_args['feature_columns'] = [fc]\n    else:\n      estimator_kw_args['linear_feature_columns'] = [fc]\n    
  estimator_kw_args['dnn_feature_columns'] = [fc]\n\n    def my_metrics(features):\n      metric = tf_keras.metrics.Mean()\n      metric.update_state(features['x'])\n      return {'mean_x': metric}\n\n    # Create a canned estimator and train to save a checkpoint.\n    input_fn = self.dataset_input_fn(\n        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)\n    canned_est = estimator_cls(**estimator_kw_args)\n    canned_est.train(input_fn=input_fn)\n\n    # Create a second canned estimator, warm-started from the first.\n    del estimator_kw_args['model_dir']\n    estimator_kw_args['warm_start_from'] = canned_est.model_dir\n    warm_started_canned_est = estimator_cls(**estimator_kw_args)\n    warm_started_canned_est.train(input_fn=input_fn)\n\n    # Create a third canned estimator, warm-started from the first.\n    input_fn = self.dataset_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size // distribution.num_replicas_in_sync,\n        shuffle=False)\n    estimator_kw_args['config'] = run_config.RunConfig(\n        train_distribute=distribution, eval_distribute=distribution)\n    warm_started_canned_est_with_ds = estimator_cls(**estimator_kw_args)\n    warm_started_canned_est_with_ds.train(input_fn=input_fn)\n\n    for variable_name in warm_started_canned_est.get_variable_names():\n      self.assertAllClose(\n          warm_started_canned_est_with_ds.get_variable_value(variable_name),\n          warm_started_canned_est.get_variable_value(variable_name))\n\n    warm_started_canned_est = add_metrics(warm_started_canned_est, my_metrics)\n    warm_started_canned_est_with_ds = add_metrics(\n        warm_started_canned_est_with_ds, my_metrics)\n\n    scores = warm_started_canned_est.evaluate(input_fn)\n    scores_with_ds = warm_started_canned_est_with_ds.evaluate(input_fn)\n    self.assertAlmostEqual(scores['loss'], scores_with_ds['loss'], 5)\n    self.assertAlmostEqual(scores['mean_x'], scores_with_ds['mean_x'], 5)\n\n\nif 
__name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/dnn.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Deep Neural Network estimators.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.feature_column import feature_column\nfrom tensorflow.python.feature_column import feature_column_lib\nfrom tensorflow.python.framework import ops\nfrom tensorflow_estimator.python.estimator import estimator\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v1\nfrom tensorflow_estimator.python.estimator.canned import head as head_lib\nfrom tensorflow_estimator.python.estimator.canned import optimizers\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.head import head_utils\nfrom tensorflow_estimator.python.estimator.head import regression_head\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n# The default learning rate of 0.05 is a historical artifact of the initial\n# implementation, but seems a reasonable choice.\n_LEARNING_RATE = 0.05\n\n\ndef _add_hidden_layer_summary(value, tag):\n  tf.compat.v1.summary.scalar('%s/fraction_of_zero_values' % tag,\n         
                     tf.math.zero_fraction(value))\n  tf.compat.v1.summary.histogram('%s/activation' % tag, value)\n\n\n@estimator_export(v1=['estimator.experimental.dnn_logit_fn_builder'])\ndef dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn,\n                         dropout, input_layer_partitioner, batch_norm):\n  \"\"\"Function builder for a dnn logit_fn.\n\n  Args:\n    units: An int indicating the dimension of the logit layer.  In the MultiHead\n      case, this should be the sum of all component Heads' logit dimensions.\n    hidden_units: Iterable of integer number of hidden units per layer.\n    feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.\n    activation_fn: Activation function applied to each layer.\n    dropout: When not `None`, the probability we will drop out a given\n      coordinate.\n    input_layer_partitioner: Partitioner for input layer.\n    batch_norm: Whether to use batch normalization after each hidden layer.\n\n  Returns:\n    A logit_fn (see below).\n\n  Raises:\n    ValueError: If units is not an int.\n  \"\"\"\n  if not isinstance(units, six.integer_types):\n    raise ValueError('units must be an int.  Given type: {}'.format(\n        type(units)))\n\n  def dnn_logit_fn(features, mode):\n    \"\"\"Deep Neural Network logit_fn.\n\n    Args:\n      features: This is the first item returned from the `input_fn` passed to\n        `train`, `evaluate`, and `predict`. This should be a single `Tensor` or\n        `dict` of same.\n      mode: Optional. Specifies if this training, evaluation or prediction. 
See\n        `ModeKeys`.\n\n    Returns:\n      A `Tensor` representing the logits, or a list of `Tensor`'s representing\n      multiple logits in the MultiHead case.\n    \"\"\"\n    dnn_model = _DNNModel(\n        units,\n        hidden_units,\n        feature_columns,\n        activation_fn,\n        dropout,\n        input_layer_partitioner,\n        batch_norm,\n        name='dnn')\n    return dnn_model(features, mode)\n\n  return dnn_logit_fn\n\n\ndef dnn_logit_fn_builder_v2(units, hidden_units, feature_columns, activation_fn,\n                            dropout, batch_norm):\n  \"\"\"Function builder for a dnn logit_fn.\n\n  Args:\n    units: An int indicating the dimension of the logit layer.  In the MultiHead\n      case, this should be the sum of all component Heads' logit dimensions.\n    hidden_units: Iterable of integer number of hidden units per layer.\n    feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.\n    activation_fn: Activation function applied to each layer.\n    dropout: When not `None`, the probability we will drop out a given\n      coordinate.\n    batch_norm: Whether to use batch normalization after each hidden layer.\n\n  Returns:\n    A logit_fn (see below).\n\n  Raises:\n    ValueError: If units is not an int.\n  \"\"\"\n  if not isinstance(units, six.integer_types):\n    raise ValueError('units must be an int.  Given type: {}'.format(\n        type(units)))\n\n  def dnn_logit_fn(features, mode):\n    \"\"\"Deep Neural Network logit_fn.\n\n    Args:\n      features: This is the first item returned from the `input_fn` passed to\n        `train`, `evaluate`, and `predict`. This should be a single `Tensor` or\n        `dict` of same.\n      mode: Optional. Specifies if this training, evaluation or prediction. 
See\n        `ModeKeys`.\n\n    Returns:\n      A `Tensor` representing the logits, or a list of `Tensor`'s representing\n      multiple logits in the MultiHead case.\n    \"\"\"\n    dnn_model = _DNNModelV2(\n        units,\n        hidden_units,\n        feature_columns,\n        activation_fn,\n        dropout,\n        batch_norm,\n        name='dnn')\n    return dnn_model(features, mode)\n\n  return dnn_logit_fn\n\n\ndef _get_previous_name_scope():\n  current_name_scope = tf.compat.v2.__internal__.get_name_scope()\n  return current_name_scope.rsplit('/', 1)[0] + '/'\n\n\nclass _DNNModel(tf_keras.Model):\n  \"\"\"A DNN Model.\"\"\"\n\n  def __init__(self,\n               units,\n               hidden_units,\n               feature_columns,\n               activation_fn,\n               dropout,\n               input_layer_partitioner,\n               batch_norm,\n               name=None,\n               **kwargs):\n    super(_DNNModel, self).__init__(name=name, **kwargs)\n    if feature_column_lib.is_feature_column_v2(feature_columns):\n      self._input_layer = tf_keras_v1.layers.DenseFeatures(\n          feature_columns=feature_columns, name='input_layer')\n    else:\n      self._input_layer = feature_column.InputLayer(\n          feature_columns=feature_columns,\n          name='input_layer',\n          create_scope_now=False)\n\n    self._add_layer(self._input_layer, 'input_layer')\n\n    self._dropout = dropout\n    self._batch_norm = batch_norm\n\n    self._hidden_layers = []\n    self._dropout_layers = []\n    self._batch_norm_layers = []\n    self._hidden_layer_scope_names = []\n    for layer_id, num_hidden_units in enumerate(hidden_units):\n      with tf.compat.v1.variable_scope(\n          'hiddenlayer_%d' % layer_id\n      ) as hidden_layer_scope:\n        hidden_layer = tf_keras_v1.__internal__.legacy.layers.Dense(\n            units=num_hidden_units,\n            activation=activation_fn,\n            
kernel_initializer=tf.compat.v1.glorot_uniform_initializer(),\n            name=hidden_layer_scope,\n            _scope=hidden_layer_scope,\n        )\n        self._add_layer(hidden_layer, hidden_layer_scope.name)\n        self._hidden_layer_scope_names.append(hidden_layer_scope.name)\n        self._hidden_layers.append(hidden_layer)\n        if self._dropout is not None:\n          dropout_layer = tf_keras_v1.__internal__.legacy.layers.Dropout(\n              rate=self._dropout\n          )\n          self._add_layer(dropout_layer, dropout_layer.name)\n          self._dropout_layers.append(dropout_layer)\n        if self._batch_norm:\n          batch_norm_layer = tf_keras_v1.__internal__.legacy.layers.BatchNormalization(\n              # The default momentum 0.99 actually crashes on certain\n              # problem, so here we use 0.999, which is the default of\n              # tf.contrib.layers.batch_norm.\n              momentum=0.999,\n              trainable=True,\n              name='batchnorm_%d' % layer_id,\n              _scope='batchnorm_%d' % layer_id)\n          self._add_layer(batch_norm_layer, batch_norm_layer.name)\n          self._batch_norm_layers.append(batch_norm_layer)\n\n    with tf.compat.v1.variable_scope('logits') as logits_scope:\n      self._logits_layer = tf_keras_v1.__internal__.legacy.layers.Dense(\n          units=units,\n          activation=None,\n          kernel_initializer=tf.compat.v1.glorot_uniform_initializer(),\n          name=logits_scope,\n          _scope=logits_scope)\n      self._add_layer(self._logits_layer, logits_scope.name)\n      self._logits_scope_name = logits_scope.name\n    self._input_layer_partitioner = input_layer_partitioner\n\n  def call(self, features, mode):\n    is_training = mode == ModeKeys.TRAIN\n    # The Keras training.Model adds a name_scope with the name of the model\n    # which modifies the constructed graph. 
Hence we add another name_scope\n    # here which is the one before the training.Model one was applied.\n    # TODO(rohanj): Remove this in TF 2.0 (b/116728605)\n    with ops.name_scope(name=_get_previous_name_scope()):\n      # TODO(rohanj): Remove dependence on variable scope for partitioning.\n      with tf.compat.v1.variable_scope(\n          'input_from_feature_columns',\n          partitioner=self._input_layer_partitioner):\n        try:\n          net = self._input_layer(features, training=is_training)\n        except TypeError:\n          net = self._input_layer(features)\n      for i in range(len(self._hidden_layers)):\n        net = self._hidden_layers[i](net)\n        if self._dropout is not None and is_training:\n          net = self._dropout_layers[i](net, training=True)\n        if self._batch_norm:\n          net = self._batch_norm_layers[i](net, training=is_training)\n        _add_hidden_layer_summary(net, self._hidden_layer_scope_names[i])\n\n      logits = self._logits_layer(net)\n      _add_hidden_layer_summary(logits, self._logits_scope_name)\n      return logits\n\n  def _add_layer(self, layer, layer_name):\n    # \"Magic\" required for keras.Model classes to track all the variables in\n    # a list of layers.Layer objects.\n    # TODO(ashankar): Figure out API so user code doesn't have to do this.\n    setattr(self, layer_name, layer)\n\n\ndef _name_from_scope_name(name):\n  \"\"\"Returns the name of an op given the name of its scope.\n\n  Args:\n    name: the name of the scope.\n\n  Returns:\n    the name of the op (equal to scope name minus any trailing slash).\n  \"\"\"\n  return name[:-1] if (name and name[-1] == '/') else name\n\n\nclass _DNNModelV2(tf_keras.Model):\n  \"\"\"A DNN Model.\"\"\"\n\n  def __init__(self,\n               units,\n               hidden_units,\n               feature_columns,\n               activation_fn,\n               dropout,\n               batch_norm,\n               name=None,\n               **kwargs):\n 
   super(_DNNModelV2, self).__init__(name=name, **kwargs)\n    with ops.name_scope(\n        'input_from_feature_columns') as input_feature_column_scope:\n      layer_name = input_feature_column_scope + 'input_layer'\n      if feature_column_lib.is_feature_column_v2(feature_columns):\n        self._input_layer = tf_keras.layers.DenseFeatures(\n            feature_columns=feature_columns, name=layer_name)\n      else:\n        raise ValueError(\n            'Received a feature column from TensorFlow v1, but this is a '\n            'TensorFlow v2 Estimator. Please either use v2 feature columns '\n            '(accessible via tf.feature_column.* in TF 2.x) with this '\n            'Estimator, or switch to a v1 Estimator for use with v1 feature '\n            'columns (accessible via tf.compat.v1.estimator.* and '\n            'tf.compat.v1.feature_column.*, respectively.')\n\n    self._dropout = dropout\n    self._batch_norm = batch_norm\n\n    self._hidden_layers = []\n    self._dropout_layers = []\n    self._batch_norm_layers = []\n    self._hidden_layer_scope_names = []\n    for layer_id, num_hidden_units in enumerate(hidden_units):\n      with ops.name_scope('hiddenlayer_%d' % layer_id) as hidden_layer_scope:\n        # Get scope name without the trailing slash.\n        hidden_shared_name = _name_from_scope_name(hidden_layer_scope)\n        hidden_layer = tf_keras.layers.Dense(\n            units=num_hidden_units,\n            activation=activation_fn,\n            kernel_initializer=tf.compat.v1.glorot_uniform_initializer(),\n            name=hidden_shared_name)\n        self._hidden_layer_scope_names.append(hidden_shared_name)\n        self._hidden_layers.append(hidden_layer)\n        if self._dropout is not None:\n          dropout_layer = tf_keras.layers.Dropout(rate=self._dropout)\n          self._dropout_layers.append(dropout_layer)\n        if self._batch_norm:\n          batch_norm_name = hidden_shared_name + '/batchnorm_%d' % layer_id\n          # 
TODO(scottzhu): Change back to use BatchNormalization when the\n          # cleanup is done.\n          batch_norm_layer = tf_keras.layers.BatchNormalization(\n              # The default momentum 0.99 actually crashes on certain\n              # problem, so here we use 0.999, which is the default of\n              # tf.contrib.layers.batch_norm.\n              momentum=0.999,\n              trainable=True,\n              name=batch_norm_name)\n          self._batch_norm_layers.append(batch_norm_layer)\n\n    with ops.name_scope('logits') as logits_scope:\n      logits_shared_name = _name_from_scope_name(logits_scope)\n      self._logits_layer = tf_keras.layers.Dense(\n          units=units,\n          activation=None,\n          kernel_initializer=tf.compat.v1.glorot_uniform_initializer(),\n          name=logits_shared_name)\n      self._logits_scope_name = logits_shared_name\n\n  def call(self, features, mode):\n    is_training = mode == ModeKeys.TRAIN\n    try:\n      net = self._input_layer(features, training=is_training)\n    except TypeError:\n      net = self._input_layer(features)\n    for i in range(len(self._hidden_layers)):\n      net = self._hidden_layers[i](net)\n      if self._dropout is not None and is_training:\n        net = self._dropout_layers[i](net, training=True)\n      if self._batch_norm:\n        net = self._batch_norm_layers[i](net, training=is_training)\n      _add_hidden_layer_summary(net, self._hidden_layer_scope_names[i])\n\n    logits = self._logits_layer(net)\n    _add_hidden_layer_summary(logits, self._logits_scope_name)\n    return logits\n\n\ndef _validate_features(features):\n  if not isinstance(features, dict):\n    raise ValueError('features should be a dictionary of `Tensor`s. 
'\n                     'Given type: {}'.format(type(features)))\n\n\ndef _get_dnn_estimator_spec(use_tpu, head, features, labels, mode, logits,\n                            optimizer):\n  \"\"\"Get EstimatorSpec for DNN Model.\"\"\"\n  if use_tpu:\n    return head._create_tpu_estimator_spec(  # pylint: disable=protected-access\n        features=features,\n        mode=mode,\n        labels=labels,\n        optimizer=optimizer,\n        logits=logits)\n  else:\n    return head.create_estimator_spec(\n        features=features,\n        mode=mode,\n        labels=labels,\n        optimizer=optimizer,\n        logits=logits)\n\n\ndef _dnn_model_fn(features,\n                  labels,\n                  mode,\n                  head,\n                  hidden_units,\n                  feature_columns,\n                  optimizer='Adagrad',\n                  activation_fn=tf.nn.relu,\n                  dropout=None,\n                  input_layer_partitioner=None,\n                  config=None,\n                  use_tpu=False,\n                  batch_norm=False):\n  \"\"\"Deep Neural Net model_fn v1.\n\n  Args:\n    features: dict of `Tensor`.\n    labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of dtype\n      `int32` or `int64` in the range `[0, n_classes)`.\n    mode: Defines whether this is training, evaluation or prediction. See\n      `ModeKeys`.\n    head: A `head_lib._Head` instance.\n    hidden_units: Iterable of integer number of hidden units per layer.\n    feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.\n    optimizer: String, `tf.Optimizer` object, or callable that creates the\n      optimizer to use for training. 
If not specified, will use the Adagrad\n      optimizer with a default learning rate of 0.05.\n    activation_fn: Activation function applied to each layer.\n    dropout: When not `None`, the probability we will drop out a given\n      coordinate.\n    input_layer_partitioner: Partitioner for input layer. Defaults to\n      `min_max_variable_partitioner` with `min_slice_size` 64 << 20.\n    config: `RunConfig` object to configure the runtime settings.\n    use_tpu: Whether to make a DNN model able to run on TPU. Will make function\n      return a `_TPUEstimatorSpec` instance and disable variable partitioning.\n    batch_norm: Whether to use batch normalization after each hidden layer.\n\n  Returns:\n    An `EstimatorSpec` instance.\n\n  Raises:\n    ValueError: If features has the wrong type.\n  \"\"\"\n\n  optimizer = optimizers.get_optimizer_instance(\n      optimizer, learning_rate=_LEARNING_RATE)\n\n  _validate_features(features)\n\n  num_ps_replicas = config.num_ps_replicas if config else 0\n\n  partitioner = (None if use_tpu else tf.compat.v1.min_max_variable_partitioner(\n      max_partitions=num_ps_replicas))\n  with tf.compat.v1.variable_scope(\n      'dnn', values=tuple(six.itervalues(features)), partitioner=partitioner):\n    input_layer_partitioner = input_layer_partitioner or (\n        None if use_tpu else tf.compat.v1.min_max_variable_partitioner(\n            max_partitions=num_ps_replicas, min_slice_size=64 << 20))\n\n    logit_fn = dnn_logit_fn_builder(\n        units=head.logits_dimension,\n        hidden_units=hidden_units,\n        feature_columns=feature_columns,\n        activation_fn=activation_fn,\n        dropout=dropout,\n        input_layer_partitioner=input_layer_partitioner,\n        batch_norm=batch_norm)\n    logits = logit_fn(features=features, mode=mode)\n\n    return _get_dnn_estimator_spec(use_tpu, head, features, labels, mode,\n                                   logits, optimizer)\n\n\ndef _dnn_model_fn_builder_v2(units, 
hidden_units, feature_columns,\n                             activation_fn, dropout, batch_norm, features,\n                             mode):\n  \"\"\"Function builder for dnn logits, trainable variables and update ops.\n\n  Args:\n    units: An int indicating the dimension of the logit layer.  In the MultiHead\n      case, this should be the sum of all component Heads' logit dimensions.\n    hidden_units: Iterable of integer number of hidden units per layer.\n    feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.\n    activation_fn: Activation function applied to each layer.\n    dropout: When not `None`, the probability we will drop out a given\n      coordinate.\n    batch_norm: Whether to use batch normalization after each hidden layer.\n    features: This is the first item returned from the `input_fn` passed to\n      `train`, `evaluate`, and `predict`. This should be a single `Tensor` or\n      `dict` of same.\n    mode: Optional. Specifies if this training, evaluation or prediction. See\n      `ModeKeys`.\n\n  Returns:\n    A `Tensor` representing the logits, or a list of `Tensor`'s representing\n      multiple logits in the MultiHead case.\n    A list of trainable variables.\n    A list of update ops.\n\n  Raises:\n    ValueError: If units is not an int.\n  \"\"\"\n  if not isinstance(units, six.integer_types):\n    raise ValueError('units must be an int.  
Given type: {}'.format(\n        type(units)))\n  dnn_model = _DNNModelV2(\n      units,\n      hidden_units,\n      feature_columns,\n      activation_fn,\n      dropout,\n      batch_norm,\n      name='dnn')\n  logits = dnn_model(features, mode)\n  trainable_variables = dnn_model.trainable_variables\n  update_ops = dnn_model.updates\n\n  return logits, trainable_variables, update_ops\n\n\ndef dnn_model_fn_v2(features,\n                    labels,\n                    mode,\n                    head,\n                    hidden_units,\n                    feature_columns,\n                    optimizer='Adagrad',\n                    activation_fn=tf.nn.relu,\n                    dropout=None,\n                    config=None,\n                    use_tpu=False,\n                    batch_norm=False):\n  \"\"\"Deep Neural Net model_fn v2.\n\n  This function is different than _dnn_model_fn_v1 in the way it handles the\n  optimizer when a String optimizer name is passed.\n\n  Args:\n    features: dict of `Tensor`.\n    labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of dtype\n      `int32` or `int64` in the range `[0, n_classes)`.\n    mode: Defines whether this is training, evaluation or prediction. See\n      `ModeKeys`.\n    head: A `base_head.Head` instance.\n    hidden_units: Iterable of integer number of hidden units per layer.\n    feature_columns: Iterable of `feature_column._FeatureColumn` model inputs.\n    optimizer: String, `tf_keras.optimizers.Optimizer` object, or callable that\n      creates the optimizer to use for training. If not specified, will use the\n      Adagrad optimizer. If it is String, the default learning rate of the\n      optimizer will be used. 
If it is String, and optimizer does not have a\n      default learning rate, then, a fixed learning rate of 0.05 is used.\n    activation_fn: Activation function applied to each layer.\n    dropout: When not `None`, the probability we will drop out a given\n      coordinate.\n    config: `RunConfig` object to configure the runtime settings.\n    use_tpu: Whether to make a DNN model able to run on TPU. Will make function\n      return a `_TPUEstimatorSpec` instance and disable variable partitioning.\n    batch_norm: Whether to use batch normalization after each hidden layer.\n\n  Returns:\n    An `EstimatorSpec` instance.\n\n  Raises:\n    ValueError: If features has the wrong type.\n  \"\"\"\n  _validate_features(features)\n\n  del config\n\n  logits, trainable_variables, update_ops = _dnn_model_fn_builder_v2(\n      units=head.logits_dimension,\n      hidden_units=hidden_units,\n      feature_columns=feature_columns,\n      activation_fn=activation_fn,\n      dropout=dropout,\n      batch_norm=batch_norm,\n      features=features,\n      mode=mode)\n\n  # In TRAIN mode, create optimizer and assign global_step variable to\n  # optimizer.iterations to make global_step increased correctly, as Hooks\n  # relies on global step as step counter.\n  if mode == ModeKeys.TRAIN:\n    optimizer = optimizers.get_optimizer_instance_v2(optimizer)\n    optimizer.iterations = tf.compat.v1.train.get_or_create_global_step()\n\n  # Create EstimatorSpec.\n  if use_tpu:\n    estimator_spec_fn = head._create_tpu_estimator_spec  # pylint: disable=protected-access\n  else:\n    estimator_spec_fn = head.create_estimator_spec  # pylint: disable=protected-access\n\n  return estimator_spec_fn(\n      features=features,\n      mode=mode,\n      labels=labels,\n      optimizer=optimizer,\n      logits=logits,\n      trainable_variables=trainable_variables,\n      update_ops=update_ops)\n\n\n@estimator_export('estimator.DNNClassifier', v1=[])\nclass DNNClassifierV2(estimator.EstimatorV2):\n  
\"\"\"A classifier for TensorFlow DNN models.\n\n  Example:\n\n  ```python\n  categorical_feature_a = categorical_column_with_hash_bucket(...)\n  categorical_feature_b = categorical_column_with_hash_bucket(...)\n\n  categorical_feature_a_emb = embedding_column(\n      categorical_column=categorical_feature_a, ...)\n  categorical_feature_b_emb = embedding_column(\n      categorical_column=categorical_feature_b, ...)\n\n  estimator = tf.estimator.DNNClassifier(\n      feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],\n      hidden_units=[1024, 512, 256])\n\n  # Or estimator using the ProximalAdagradOptimizer optimizer with\n  # regularization.\n  estimator = tf.estimator.DNNClassifier(\n      feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],\n      hidden_units=[1024, 512, 256],\n      optimizer=tf.compat.v1.train.ProximalAdagradOptimizer(\n        learning_rate=0.1,\n        l1_regularization_strength=0.001\n      ))\n\n  # Or estimator using an optimizer with a learning rate decay.\n  estimator = tf.estimator.DNNClassifier(\n      feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],\n      hidden_units=[1024, 512, 256],\n      optimizer=lambda: tf_keras.optimizers.Adam(\n          learning_rate=tf.compat.v1.train.exponential_decay(\n              learning_rate=0.1,\n              global_step=tf.compat.v1.train.get_global_step(),\n              decay_steps=10000,\n              decay_rate=0.96))\n\n  # Or estimator with warm-starting from a previous checkpoint.\n  estimator = tf.estimator.DNNClassifier(\n      feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],\n      hidden_units=[1024, 512, 256],\n      warm_start_from=\"/path/to/checkpoint/dir\")\n\n  # Input builders\n  def input_fn_train:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n  def input_fn_eval:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents 
label's class\n    # index.\n    pass\n  def input_fn_predict:\n    # Returns tf.data.Dataset of (x, None) tuple.\n    pass\n  estimator.train(input_fn=input_fn_train)\n  metrics = estimator.evaluate(input_fn=input_fn_eval)\n  predictions = estimator.predict(input_fn=input_fn_predict)\n  ```\n\n  Input of `train` and `evaluate` should have following features,\n  otherwise there will be a `KeyError`:\n\n  * if `weight_column` is not `None`, a feature with `key=weight_column` whose\n    value is a `Tensor`.\n  * for each `column` in `feature_columns`:\n    - if `column` is a `CategoricalColumn`, a feature with `key=column.name`\n      whose `value` is a `SparseTensor`.\n    - if `column` is a `WeightedCategoricalColumn`, two features: the first\n      with `key` the id column name, the second with `key` the weight column\n      name. Both features' `value` must be a `SparseTensor`.\n    - if `column` is a `DenseColumn`, a feature with `key=column.name`\n      whose `value` is a `Tensor`.\n\n  Loss is calculated by using softmax cross entropy.\n\n  @compatibility(eager)\n  Estimators can be used while eager execution is enabled. Note that `input_fn`\n  and all hooks are executed inside a graph context, so they have to be written\n  to be compatible with graph mode. Note that `input_fn` code using `tf.data`\n  generally works in both graph and eager modes.\n  @end_compatibility\n  \"\"\"\n\n  def __init__(\n      self,\n      hidden_units,\n      feature_columns,\n      model_dir=None,\n      n_classes=2,\n      weight_column=None,\n      label_vocabulary=None,\n      optimizer='Adagrad',\n      activation_fn=tf.nn.relu,\n      dropout=None,\n      config=None,\n      warm_start_from=None,\n      loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,\n      batch_norm=False,\n  ):\n    \"\"\"Initializes a `DNNClassifier` instance.\n\n    Args:\n      hidden_units: Iterable of number hidden units per layer. All layers are\n        fully connected. Ex. 
`[64, 32]` means first layer has 64 nodes and\n        second one has 32.\n      feature_columns: An iterable containing all the feature columns used by\n        the model. All items in the set should be instances of classes derived\n        from `_FeatureColumn`.\n      model_dir: Directory to save model parameters, graph and etc. This can\n        also be used to load checkpoints from the directory into a estimator to\n        continue training a previously saved model.\n      n_classes: Number of label classes. Defaults to 2, namely binary\n        classification. Must be > 1.\n      weight_column: A string or a `NumericColumn` created by\n        `tf.feature_column.numeric_column` defining feature column representing\n        weights. It is used to down weight or boost examples during training. It\n        will be multiplied by the loss of the example. If it is a string, it is\n        used as a key to fetch weight tensor from the `features`. If it is a\n        `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then\n        weight_column.normalizer_fn is applied on it to get weight tensor.\n      label_vocabulary: A list of strings represents possible label values. If\n        given, labels must be string type and have any value in\n        `label_vocabulary`. If it is not given, that means labels are already\n        encoded as integer or float within [0, 1] for `n_classes=2` and encoded\n        as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also\n        there will be errors if vocabulary is not provided and labels are\n        string.\n      optimizer: An instance of `tf_keras.optimizers.*` used to train the model.\n        Can also be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp',\n        SGD'), or callable. Defaults to Adagrad optimizer.\n      activation_fn: Activation function applied to each layer. 
If `None`, will\n        use `tf.nn.relu`.\n      dropout: When not `None`, the probability we will drop out a given\n        coordinate.\n      config: `RunConfig` object to configure the runtime settings.\n      warm_start_from: A string filepath to a checkpoint to warm-start from, or\n        a `WarmStartSettings` object to fully configure warm-starting.  If the\n        string filepath is provided instead of a `WarmStartSettings`, then all\n        weights are warm-started, and it is assumed that vocabularies and Tensor\n        names are unchanged.\n      loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how\n        to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.\n      batch_norm: Whether to use batch normalization after each hidden layer.\n    \"\"\"\n    head = head_utils.binary_or_multi_class_head(\n        n_classes,\n        weight_column=weight_column,\n        label_vocabulary=label_vocabulary,\n        loss_reduction=loss_reduction)\n    estimator._canned_estimator_api_gauge.get_cell('Classifier').set('DNN')\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"Call the defined shared dnn_model_fn_v2.\"\"\"\n      return dnn_model_fn_v2(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          hidden_units=hidden_units,\n          feature_columns=tuple(feature_columns or []),\n          optimizer=optimizer,\n          activation_fn=activation_fn,\n          dropout=dropout,\n          config=config,\n          batch_norm=batch_norm)\n\n    super(DNNClassifierV2, self).__init__(\n        model_fn=_model_fn,\n        model_dir=model_dir,\n        config=config,\n        warm_start_from=warm_start_from)\n\n\n@estimator_export(v1=['estimator.DNNClassifier'])  # pylint: disable=missing-docstring\nclass DNNClassifier(estimator.Estimator):\n  __doc__ = DNNClassifierV2.__doc__.replace('SUM_OVER_BATCH_SIZE', 'SUM')\n\n  def __init__(\n      self,\n    
  hidden_units,\n      feature_columns,\n      model_dir=None,\n      n_classes=2,\n      weight_column=None,\n      label_vocabulary=None,\n      optimizer='Adagrad',\n      activation_fn=tf.nn.relu,\n      dropout=None,\n      input_layer_partitioner=None,\n      config=None,\n      warm_start_from=None,\n      loss_reduction=tf.compat.v1.losses.Reduction.SUM,\n      batch_norm=False,\n  ):\n    head = head_lib._binary_logistic_or_multi_class_head(  # pylint: disable=protected-access\n        n_classes, weight_column, label_vocabulary, loss_reduction)\n    estimator._canned_estimator_api_gauge.get_cell('Classifier').set('DNN')\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"Call the defined shared dnn_model_fn.\"\"\"\n      return _dnn_model_fn(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          hidden_units=hidden_units,\n          feature_columns=tuple(feature_columns or []),\n          optimizer=optimizer,\n          activation_fn=activation_fn,\n          dropout=dropout,\n          input_layer_partitioner=input_layer_partitioner,\n          config=config,\n          batch_norm=batch_norm)\n\n    super(DNNClassifier, self).__init__(\n        model_fn=_model_fn,\n        model_dir=model_dir,\n        config=config,\n        warm_start_from=warm_start_from)\n\n\n@estimator_export('estimator.DNNEstimator', v1=[])\nclass DNNEstimatorV2(estimator.EstimatorV2):\n  \"\"\"An estimator for TensorFlow DNN models with user-specified head.\n\n  Example:\n\n  ```python\n  sparse_feature_a = sparse_column_with_hash_bucket(...)\n  sparse_feature_b = sparse_column_with_hash_bucket(...)\n\n  sparse_feature_a_emb = embedding_column(sparse_id_column=sparse_feature_a,\n                                          ...)\n  sparse_feature_b_emb = embedding_column(sparse_id_column=sparse_feature_b,\n                                          ...)\n\n  estimator = tf.estimator.DNNEstimator(\n      
head=tf.estimator.MultiLabelHead(n_classes=3),\n      feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],\n      hidden_units=[1024, 512, 256])\n\n  # Or estimator using the ProximalAdagradOptimizer optimizer with\n  # regularization.\n  estimator = tf.estimator.DNNEstimator(\n      head=tf.estimator.MultiLabelHead(n_classes=3),\n      feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],\n      hidden_units=[1024, 512, 256],\n      optimizer=tf.compat.v1.train.ProximalAdagradOptimizer(\n        learning_rate=0.1,\n        l1_regularization_strength=0.001\n      ))\n\n  # Or estimator using an optimizer with a learning rate decay.\n  estimator = tf.estimator.DNNEstimator(\n      head=tf.estimator.MultiLabelHead(n_classes=3),\n      feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],\n      hidden_units=[1024, 512, 256],\n      optimizer=lambda: tf_keras.optimizers.Adam(\n          learning_rate=tf.compat.v1.train.exponential_decay(\n              learning_rate=0.1,\n              global_step=tf.compat.v1.train.get_global_step(),\n              decay_steps=10000,\n              decay_rate=0.96))\n\n  # Or estimator with warm-starting from a previous checkpoint.\n  estimator = tf.estimator.DNNEstimator(\n      head=tf.estimator.MultiLabelHead(n_classes=3),\n      feature_columns=[sparse_feature_a_emb, sparse_feature_b_emb],\n      hidden_units=[1024, 512, 256],\n      warm_start_from=\"/path/to/checkpoint/dir\")\n\n  # Input builders\n  def input_fn_train:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n  def input_fn_eval:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n  def input_fn_predict:\n    # Returns tf.data.Dataset of (x, None) tuple.\n    pass\n  estimator.train(input_fn=input_fn_train)\n  metrics = estimator.evaluate(input_fn=input_fn_eval)\n  predictions = estimator.predict(input_fn=input_fn_predict)\n  ```\n\n  
Input of `train` and `evaluate` should have following features,\n  otherwise there will be a `KeyError`:\n\n  * if `weight_column` is not `None`, a feature with `key=weight_column` whose\n    value is a `Tensor`.\n  * for each `column` in `feature_columns`:\n    - if `column` is a `CategoricalColumn`, a feature with `key=column.name`\n      whose `value` is a `SparseTensor`.\n    - if `column` is a `WeightedCategoricalColumn`, two features: the first\n      with `key` the id column name, the second with `key` the weight column\n      name. Both features' `value` must be a `SparseTensor`.\n    - if `column` is a `DenseColumn`, a feature with `key=column.name`\n      whose `value` is a `Tensor`.\n\n  Loss and predicted output are determined by the specified head.\n\n  @compatibility(eager)\n  Estimators can be used while eager execution is enabled. Note that `input_fn`\n  and all hooks are executed inside a graph context, so they have to be written\n  to be compatible with graph mode. Note that `input_fn` code using `tf.data`\n  generally works in both graph and eager modes.\n  @end_compatibility\n  \"\"\"\n\n  def __init__(self,\n               head,\n               hidden_units,\n               feature_columns,\n               model_dir=None,\n               optimizer='Adagrad',\n               activation_fn=tf.nn.relu,\n               dropout=None,\n               config=None,\n               warm_start_from=None,\n               batch_norm=False):\n    \"\"\"Initializes a `DNNEstimator` instance.\n\n    Args:\n      head: A `_Head` instance constructed with a method such as\n        `tf.contrib.estimator.multi_label_head`.\n      hidden_units: Iterable of number hidden units per layer. All layers are\n        fully connected. Ex. `[64, 32]` means first layer has 64 nodes and\n        second one has 32.\n      feature_columns: An iterable containing all the feature columns used by\n        the model. 
All items in the set should be instances of classes derived\n        from `_FeatureColumn`.\n      model_dir: Directory to save model parameters, graph and etc. This can\n        also be used to load checkpoints from the directory into a estimator to\n        continue training a previously saved model.\n      optimizer: An instance of `tf_keras.optimizers.*` used to train the model.\n        Can also be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp',\n        SGD'), or callable. Defaults to Adagrad optimizer.\n      activation_fn: Activation function applied to each layer. If `None`, will\n        use `tf.nn.relu`.\n      dropout: When not `None`, the probability we will drop out a given\n        coordinate.\n      config: `RunConfig` object to configure the runtime settings.\n      warm_start_from: A string filepath to a checkpoint to warm-start from, or\n        a `WarmStartSettings` object to fully configure warm-starting.  If the\n        string filepath is provided instead of a `WarmStartSettings`, then all\n        weights are warm-started, and it is assumed that vocabularies and Tensor\n        names are unchanged.\n      batch_norm: Whether to use batch normalization after each hidden layer.\n    \"\"\"\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"Call the defined shared dnn_model_fn_v2.\"\"\"\n      return dnn_model_fn_v2(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          hidden_units=hidden_units,\n          feature_columns=tuple(feature_columns or []),\n          optimizer=optimizer,\n          activation_fn=activation_fn,\n          dropout=dropout,\n          config=config,\n          batch_norm=batch_norm)\n\n    estimator._canned_estimator_api_gauge.get_cell('Estimator').set('DNN')  # pylint: disable=protected-access\n    super(DNNEstimatorV2, self).__init__(\n        model_fn=_model_fn,\n        model_dir=model_dir,\n        config=config,\n        
warm_start_from=warm_start_from)\n\n\n@estimator_export(v1=['estimator.DNNEstimator'])  # pylint: disable=missing-docstring\nclass DNNEstimator(estimator.Estimator):\n  __doc__ = DNNEstimatorV2.__doc__\n\n  def __init__(self,\n               head,\n               hidden_units,\n               feature_columns,\n               model_dir=None,\n               optimizer='Adagrad',\n               activation_fn=tf.nn.relu,\n               dropout=None,\n               input_layer_partitioner=None,\n               config=None,\n               warm_start_from=None,\n               batch_norm=False):\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"Call the defined shared _dnn_model_fn.\"\"\"\n      return _dnn_model_fn(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          hidden_units=hidden_units,\n          feature_columns=tuple(feature_columns or []),\n          optimizer=optimizer,\n          activation_fn=activation_fn,\n          dropout=dropout,\n          input_layer_partitioner=input_layer_partitioner,\n          config=config,\n          batch_norm=batch_norm)\n\n    estimator._canned_estimator_api_gauge.get_cell('Estimator').set('DNN')  # pylint: disable=protected-access\n    super(DNNEstimator, self).__init__(\n        model_fn=_model_fn,\n        model_dir=model_dir,\n        config=config,\n        warm_start_from=warm_start_from)\n\n\n@estimator_export('estimator.DNNRegressor', v1=[])\nclass DNNRegressorV2(estimator.EstimatorV2):\n  \"\"\"A regressor for TensorFlow DNN models.\n\n  Example:\n\n  ```python\n  categorical_feature_a = categorical_column_with_hash_bucket(...)\n  categorical_feature_b = categorical_column_with_hash_bucket(...)\n\n  categorical_feature_a_emb = embedding_column(\n      categorical_column=categorical_feature_a, ...)\n  categorical_feature_b_emb = embedding_column(\n      categorical_column=categorical_feature_b, ...)\n\n  estimator = 
tf.estimator.DNNRegressor(\n      feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],\n      hidden_units=[1024, 512, 256])\n\n  # Or estimator using the ProximalAdagradOptimizer optimizer with\n  # regularization.\n  estimator = tf.estimator.DNNRegressor(\n      feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],\n      hidden_units=[1024, 512, 256],\n      optimizer=tf.compat.v1.train.ProximalAdagradOptimizer(\n        learning_rate=0.1,\n        l1_regularization_strength=0.001\n      ))\n\n  # Or estimator using an optimizer with a learning rate decay.\n  estimator = tf.estimator.DNNRegressor(\n      feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],\n      hidden_units=[1024, 512, 256],\n      optimizer=lambda: tf_keras.optimizers.Adam(\n          learning_rate=tf.compat.v1.train.exponential_decay(\n              learning_rate=0.1,\n              global_step=tf.compat.v1.train.get_global_step(),\n              decay_steps=10000,\n              decay_rate=0.96))\n\n  # Or estimator with warm-starting from a previous checkpoint.\n  estimator = tf.estimator.DNNRegressor(\n      feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],\n      hidden_units=[1024, 512, 256],\n      warm_start_from=\"/path/to/checkpoint/dir\")\n\n  # Input builders\n  def input_fn_train:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n  def input_fn_eval:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n  def input_fn_predict:\n    # Returns tf.data.Dataset of (x, None) tuple.\n    pass\n  estimator.train(input_fn=input_fn_train)\n  metrics = estimator.evaluate(input_fn=input_fn_eval)\n  predictions = estimator.predict(input_fn=input_fn_predict)\n  ```\n\n  Input of `train` and `evaluate` should have following features,\n  otherwise there will be a `KeyError`:\n\n  * if `weight_column` is not 
`None`, a feature with `key=weight_column` whose\n    value is a `Tensor`.\n  * for each `column` in `feature_columns`:\n    - if `column` is a `CategoricalColumn`, a feature with `key=column.name`\n      whose `value` is a `SparseTensor`.\n    - if `column` is a `WeightedCategoricalColumn`, two features: the first\n      with `key` the id column name, the second with `key` the weight column\n      name. Both features' `value` must be a `SparseTensor`.\n    - if `column` is a `DenseColumn`, a feature with `key=column.name`\n      whose `value` is a `Tensor`.\n\n  Loss is calculated by using mean squared error.\n\n  @compatibility(eager)\n  Estimators can be used while eager execution is enabled. Note that `input_fn`\n  and all hooks are executed inside a graph context, so they have to be written\n  to be compatible with graph mode. Note that `input_fn` code using `tf.data`\n  generally works in both graph and eager modes.\n  @end_compatibility\n  \"\"\"\n\n  def __init__(\n      self,\n      hidden_units,\n      feature_columns,\n      model_dir=None,\n      label_dimension=1,\n      weight_column=None,\n      optimizer='Adagrad',\n      activation_fn=tf.nn.relu,\n      dropout=None,\n      config=None,\n      warm_start_from=None,\n      loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,\n      batch_norm=False,\n  ):\n    \"\"\"Initializes a `DNNRegressor` instance.\n\n    Args:\n      hidden_units: Iterable of number hidden units per layer. All layers are\n        fully connected. Ex. `[64, 32]` means first layer has 64 nodes and\n        second one has 32.\n      feature_columns: An iterable containing all the feature columns used by\n        the model. All items in the set should be instances of classes derived\n        from `FeatureColumn`.\n      model_dir: Directory to save model parameters, graph and etc. 
This can\n        also be used to load checkpoints from the directory into a estimator to\n        continue training a previously saved model.\n      label_dimension: Number of regression targets per example. This is the\n        size of the last dimension of the labels and logits `Tensor` objects\n        (typically, these have shape `[batch_size, label_dimension]`).\n      weight_column: A string or a `NumericColumn` created by\n        `tf.feature_column.numeric_column` defining feature column representing\n        weights. It is used to down weight or boost examples during training. It\n        will be multiplied by the loss of the example. If it is a string, it is\n        used as a key to fetch weight tensor from the `features`. If it is a\n        `NumericColumn`, raw tensor is fetched by key `weight_column.key`, then\n        weight_column.normalizer_fn is applied on it to get weight tensor.\n      optimizer: An instance of `tf_keras.optimizers.*` used to train the model.\n        Can also be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp',\n        SGD'), or callable. Defaults to Adagrad optimizer.\n      activation_fn: Activation function applied to each layer. If `None`, will\n        use `tf.nn.relu`.\n      dropout: When not `None`, the probability we will drop out a given\n        coordinate.\n      config: `RunConfig` object to configure the runtime settings.\n      warm_start_from: A string filepath to a checkpoint to warm-start from, or\n        a `WarmStartSettings` object to fully configure warm-starting.  If the\n        string filepath is provided instead of a `WarmStartSettings`, then all\n        weights are warm-started, and it is assumed that vocabularies and Tensor\n        names are unchanged.\n      loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how\n        to reduce training loss over batch. 
Defaults to `SUM_OVER_BATCH_SIZE`.\n      batch_norm: Whether to use batch normalization after each hidden layer.\n    \"\"\"\n    head = regression_head.RegressionHead(\n        label_dimension=label_dimension,\n        weight_column=weight_column,\n        loss_reduction=loss_reduction)\n    estimator._canned_estimator_api_gauge.get_cell('Regressor').set('DNN')  # pylint: disable=protected-access\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"Call the defined shared dnn_model_fn_v2.\"\"\"\n      return dnn_model_fn_v2(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          hidden_units=hidden_units,\n          feature_columns=tuple(feature_columns or []),\n          optimizer=optimizer,\n          activation_fn=activation_fn,\n          dropout=dropout,\n          config=config,\n          batch_norm=batch_norm)\n\n    super(DNNRegressorV2, self).__init__(\n        model_fn=_model_fn,\n        model_dir=model_dir,\n        config=config,\n        warm_start_from=warm_start_from)\n\n\n@estimator_export(v1=['estimator.DNNRegressor'])  # pylint: disable=missing-docstring\nclass DNNRegressor(estimator.Estimator):\n  __doc__ = DNNRegressorV2.__doc__.replace('SUM_OVER_BATCH_SIZE', 'SUM')\n\n  def __init__(\n      self,\n      hidden_units,\n      feature_columns,\n      model_dir=None,\n      label_dimension=1,\n      weight_column=None,\n      optimizer='Adagrad',\n      activation_fn=tf.nn.relu,\n      dropout=None,\n      input_layer_partitioner=None,\n      config=None,\n      warm_start_from=None,\n      loss_reduction=tf.compat.v1.losses.Reduction.SUM,\n      batch_norm=False,\n  ):\n    head = head_lib._regression_head(  # pylint: disable=protected-access\n        label_dimension=label_dimension,\n        weight_column=weight_column,\n        loss_reduction=loss_reduction)\n    estimator._canned_estimator_api_gauge.get_cell('Regressor').set('DNN')  # pylint: disable=protected-access\n\n  
  def _model_fn(features, labels, mode, config):\n      \"\"\"Call the defined shared _dnn_model_fn.\"\"\"\n      return _dnn_model_fn(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          hidden_units=hidden_units,\n          feature_columns=tuple(feature_columns or []),\n          optimizer=optimizer,\n          activation_fn=activation_fn,\n          dropout=dropout,\n          input_layer_partitioner=input_layer_partitioner,\n          config=config,\n          batch_norm=batch_norm)\n\n    super(DNNRegressor, self).__init__(\n        model_fn=_model_fn,\n        model_dir=model_dir,\n        config=config,\n        warm_start_from=warm_start_from)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/dnn_estimator_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for DNNEstimatorV2.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport shutil\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import dnn\nfrom tensorflow_estimator.python.estimator.canned import dnn_testing_utils\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.head import multi_class_head\nfrom tensorflow_estimator.python.estimator.head import regression_head\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\n\n\ndef _dnn_estimator_fn(weight_column=None, label_dimension=1, **kwargs):\n  \"\"\"Returns a DNNEstimator that uses regression_head.\"\"\"\n  return dnn.DNNEstimatorV2(\n      head=regression_head.RegressionHead(\n          weight_column=weight_column,\n          label_dimension=label_dimension,\n          # Tests in core (from which this test inherits) test the sum loss.\n          loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE),\n      **kwargs)\n\n\ndef 
_dnn_estimator_classifier_fn(n_classes=3, **kwargs):\n  return dnn.DNNEstimatorV2(\n      head=multi_class_head.MultiClassHead(n_classes=n_classes), **kwargs)\n\n\nclass DNNLogitFnBuilderTest(tf.test.TestCase):\n\n  def testLongInPy2(self):\n    if six.PY2:\n      ret = dnn.dnn_logit_fn_builder(\n          long(1), None, None, None, None, None, None)\n      self.assertTrue(callable(ret))\n\n\nclass DNNEstimatorEvaluateTest(dnn_testing_utils.BaseDNNRegressorEvaluateTest,\n                               tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(\n        self, _dnn_estimator_fn)\n\n\nclass DNNEstimatorPredictTest(dnn_testing_utils.BaseDNNRegressorPredictTest,\n                              tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(\n        self, _dnn_estimator_fn)\n\n\nclass DNNEstimatorTrainTest(dnn_testing_utils.BaseDNNRegressorTrainTest,\n                            tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(\n        self, _dnn_estimator_fn)\n\n\nclass DNNEstimatorWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest,\n                                   tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNWarmStartingTest.__init__(\n        self, _dnn_estimator_classifier_fn, _dnn_estimator_fn)\n\n\nclass DNNEstimatorIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if 
self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self,\n                          train_input_fn,\n                          eval_input_fn,\n                          predict_input_fn,\n                          input_dimension,\n                          label_dimension,\n                          batch_size,\n                          optimizer='Adagrad'):\n    feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    est = dnn.DNNEstimatorV2(\n        head=regression_head.RegressionHead(label_dimension=label_dimension),\n        hidden_units=(2, 2),\n        feature_columns=feature_columns,\n        optimizer=optimizer,\n        model_dir=self._model_dir)\n\n    # Train\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # Evaluate\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # Predict\n    predictions = np.array([\n        x[prediction_keys.PredictionKeys.PREDICTIONS]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, label_dimension), predictions.shape)\n\n    # Export\n    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def _create_input_fn(self, label_dimension, batch_size):\n    \"\"\"Creates input_fn for integration test.\"\"\"\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n    # learn y = x\n    train_input_fn = 
numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, batch_size=batch_size, shuffle=False)\n\n    return train_input_fn, eval_input_fn, predict_input_fn\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    train_input_fn, eval_input_fn, predict_input_fn = self._create_input_fn(\n        label_dimension, batch_size)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size)\n\n  def test_numpy_input_fn_with_optimizer_instance(self):\n    \"\"\"Tests complete flow with optimizer_v2 instance.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    train_input_fn, eval_input_fn, predict_input_fn = self._create_input_fn(\n        label_dimension, batch_size)\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size,\n        optimizer=tf_keras.optimizers.legacy.Adagrad(\n            0.01))  # Test with optimizer_v2 instance\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/dnn_linear_combined.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"TensorFlow estimators for Linear and DNN joined training models.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\n\nimport six\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator import estimator\nfrom tensorflow_estimator.python.estimator.canned import dnn\nfrom tensorflow_estimator.python.estimator.canned import head as head_lib\nfrom tensorflow_estimator.python.estimator.canned import linear\nfrom tensorflow_estimator.python.estimator.canned import optimizers\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.head import head_utils\nfrom tensorflow_estimator.python.estimator.head import regression_head\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n# The default learning rates are a historical artifact of the initial\n# implementation.\n_DNN_LEARNING_RATE = 0.001\n_LINEAR_LEARNING_RATE = 0.005\n\n\ndef _check_no_sync_replicas_optimizer(optimizer):\n  if isinstance(optimizer, tf.compat.v1.train.SyncReplicasOptimizer):\n    raise ValueError(\n        'SyncReplicasOptimizer does not support multi optimizers case. 
'\n        'Therefore, it is not supported in DNNLinearCombined model. '\n        'If you want to use this optimizer, please use either DNN or Linear '\n        'model.')\n\n\ndef _linear_learning_rate(num_linear_feature_columns):\n  \"\"\"Returns the default learning rate of the linear model.\n\n  The calculation is a historical artifact of this initial implementation, but\n  has proven a reasonable choice.\n\n  Args:\n    num_linear_feature_columns: The number of feature columns of the linear\n      model.\n\n  Returns:\n    A float.\n  \"\"\"\n  default_learning_rate = 1. / math.sqrt(num_linear_feature_columns)\n  return min(_LINEAR_LEARNING_RATE, default_learning_rate)\n\n\ndef _add_layer_summary(value, tag):\n  tf.compat.v1.summary.scalar('%s/fraction_of_zero_values' % tag,\n                              tf.math.zero_fraction(value))\n  tf.compat.v1.summary.histogram('%s/activation' % tag, value)\n\n\ndef _validate_feature_columns(linear_feature_columns, dnn_feature_columns):\n  \"\"\"Validates feature columns DNNLinearCombinedRegressor.\"\"\"\n  linear_feature_columns = linear_feature_columns or []\n  dnn_feature_columns = dnn_feature_columns or []\n  feature_columns = (list(linear_feature_columns) + list(dnn_feature_columns))\n  if not feature_columns:\n    raise ValueError('Either linear_feature_columns or dnn_feature_columns '\n                     'must be defined.')\n  return feature_columns\n\n\ndef _dnn_linear_combined_model_fn_v2(\n    features,\n    labels,\n    mode,\n    head,\n    linear_feature_columns=None,\n    linear_optimizer='Ftrl',\n    dnn_feature_columns=None,\n    dnn_optimizer='Adagrad',\n    dnn_hidden_units=None,\n    dnn_activation_fn=tf.nn.relu,\n    dnn_dropout=None,\n    config=None,\n    batch_norm=False,\n    linear_sparse_combiner='sum',\n    loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE):\n  \"\"\"Deep Neural Net and Linear combined model_fn.\n\n  Args:\n    features: dict of `Tensor`.\n    labels: `Tensor` of shape 
[batch_size, 1] or [batch_size] labels of dtype\n      `int32` or `int64` in the range `[0, n_classes)`.\n    mode: Defines whether this is training, evaluation or prediction. See\n      `ModeKeys`.\n    head: A `Head` instance.\n    linear_feature_columns: An iterable containing all the feature columns used\n      by the Linear model.\n    linear_optimizer: string, `Optimizer` object, or callable that defines the\n      optimizer to use for training the Linear model. Defaults to the Ftrl\n      optimizer.\n    dnn_feature_columns: An iterable containing all the feature columns used by\n      the DNN model.\n    dnn_optimizer: string, `Optimizer` object, or callable that defines the\n      optimizer to use for training the DNN model. Defaults to the Adagrad\n      optimizer.\n    dnn_hidden_units: List of hidden units per DNN layer.\n    dnn_activation_fn: Activation function applied to each DNN layer. If `None`,\n      will use `tf.nn.relu`.\n    dnn_dropout: When not `None`, the probability we will drop out a given DNN\n      coordinate.\n    config: `RunConfig` object to configure the runtime settings.\n    batch_norm: Whether to use batch normalization after each hidden layer.\n    linear_sparse_combiner: A string specifying how to reduce the linear model\n      if a categorical column is multivalent.  One of \"mean\", \"sqrtn\", and\n      \"sum\".\n    loss_reduction: One of `tf_keras.losses.Reduction` except `NONE`. Describes\n      how to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.\n\n  Returns:\n    An `EstimatorSpec` instance.\n\n  Raises:\n    ValueError: If both `linear_feature_columns` and `dnn_features_columns`\n      are empty at the same time, or `input_layer_partitioner` is missing,\n      or features has the wrong type.\n  \"\"\"\n  if not isinstance(features, dict):\n    raise ValueError('features should be a dictionary of `Tensor`s. 
'\n                     'Given type: {}'.format(type(features)))\n  if not linear_feature_columns and not dnn_feature_columns:\n    raise ValueError(\n        'Either linear_feature_columns or dnn_feature_columns must be defined.')\n\n  del config\n\n  # Build DNN Logits.\n  if not dnn_feature_columns:\n    dnn_logits = None\n  else:\n    if mode == ModeKeys.TRAIN:\n      dnn_optimizer = optimizers.get_optimizer_instance_v2(\n          dnn_optimizer, learning_rate=_DNN_LEARNING_RATE)\n      _check_no_sync_replicas_optimizer(dnn_optimizer)\n\n    if not dnn_hidden_units:\n      raise ValueError(\n          'dnn_hidden_units must be defined when dnn_feature_columns is '\n          'specified.')\n    dnn_logits, dnn_trainable_variables, dnn_update_ops = (\n        dnn._dnn_model_fn_builder_v2(  # pylint: disable=protected-access\n            units=head.logits_dimension,\n            hidden_units=dnn_hidden_units,\n            feature_columns=dnn_feature_columns,\n            activation_fn=dnn_activation_fn,\n            dropout=dnn_dropout,\n            batch_norm=batch_norm,\n            features=features,\n            mode=mode))\n\n  if not linear_feature_columns:\n    linear_logits = None\n  else:\n    if mode == ModeKeys.TRAIN:\n      linear_optimizer = optimizers.get_optimizer_instance_v2(\n          linear_optimizer,\n          learning_rate=_linear_learning_rate(len(linear_feature_columns)))\n      _check_no_sync_replicas_optimizer(linear_optimizer)\n\n    linear_logits, linear_trainable_variables = (\n        linear._linear_model_fn_builder_v2(  # pylint: disable=protected-access\n            units=head.logits_dimension,\n            feature_columns=linear_feature_columns,\n            sparse_combiner=linear_sparse_combiner,\n            features=features))\n    _add_layer_summary(linear_logits, 'linear')\n\n  # Combine logits and build full model.\n  if dnn_logits is not None and linear_logits is not None:\n    logits = dnn_logits + linear_logits\n  elif 
 dnn_logits is not None:\n    logits = dnn_logits\n  else:\n    logits = linear_logits\n\n  def _train_op_fn(loss):\n    \"\"\"Returns the op to optimize the loss.\"\"\"\n    train_ops = []\n    # Scale loss by number of replicas.\n    if loss_reduction == tf.losses.Reduction.SUM_OVER_BATCH_SIZE:\n      num_replicas = tf.distribute.get_strategy().num_replicas_in_sync\n      if num_replicas > 1:\n        loss *= (1. / num_replicas)\n\n    if dnn_logits is not None:\n      train_ops.extend(dnn_optimizer.get_updates(loss, dnn_trainable_variables))\n      if dnn_update_ops is not None:\n        train_ops.extend(dnn_update_ops)\n    if linear_logits is not None:\n      train_ops.extend(\n          linear_optimizer.get_updates(loss, linear_trainable_variables))\n    train_op = tf.group(*train_ops)\n    return train_op\n\n  # In TRAIN mode, assign the global_step variable to optimizer.iterations to\n  # make global_step increase correctly, as Hooks rely on global step as\n  # step counter. Note that only one model's optimizer needs this assignment.\n  if mode == ModeKeys.TRAIN:\n    if dnn_logits is not None:\n      dnn_optimizer.iterations = tf.compat.v1.train.get_or_create_global_step()\n    else:\n      linear_optimizer.iterations = \\\n        tf.compat.v1.train.get_or_create_global_step()\n\n  return head.create_estimator_spec(\n      features=features,\n      mode=mode,\n      labels=labels,\n      train_op_fn=_train_op_fn,\n      logits=logits)\n\n\ndef _dnn_linear_combined_model_fn(features,\n                                  labels,\n                                  mode,\n                                  head,\n                                  linear_feature_columns=None,\n                                  linear_optimizer='Ftrl',\n                                  dnn_feature_columns=None,\n                                  dnn_optimizer='Adagrad',\n                                  dnn_hidden_units=None,\n                                  
dnn_activation_fn=tf.nn.relu,\n                                  dnn_dropout=None,\n                                  input_layer_partitioner=None,\n                                  config=None,\n                                  batch_norm=False,\n                                  linear_sparse_combiner='sum'):\n  \"\"\"Deep Neural Net and Linear combined model_fn.\n\n  Args:\n    features: dict of `Tensor`.\n    labels: `Tensor` of shape [batch_size, 1] or [batch_size] labels of dtype\n      `int32` or `int64` in the range `[0, n_classes)`.\n    mode: Defines whether this is training, evaluation or prediction. See\n      `ModeKeys`.\n    head: A `Head` instance.\n    linear_feature_columns: An iterable containing all the feature columns used\n      by the Linear model.\n    linear_optimizer: string, `Optimizer` object, or callable that defines the\n      optimizer to use for training the Linear model. Defaults to the Ftrl\n      optimizer.\n    dnn_feature_columns: An iterable containing all the feature columns used by\n      the DNN model.\n    dnn_optimizer: string, `Optimizer` object, or callable that defines the\n      optimizer to use for training the DNN model. Defaults to the Adagrad\n      optimizer.\n    dnn_hidden_units: List of hidden units per DNN layer.\n    dnn_activation_fn: Activation function applied to each DNN layer. If `None`,\n      will use `tf.nn.relu`.\n    dnn_dropout: When not `None`, the probability we will drop out a given DNN\n      coordinate.\n    input_layer_partitioner: Partitioner for input layer.\n    config: `RunConfig` object to configure the runtime settings.\n    batch_norm: Whether to use batch normalization after each hidden layer.\n    linear_sparse_combiner: A string specifying how to reduce the linear model\n      if a categorical column is multivalent.  
One of \"mean\", \"sqrtn\", and\n      \"sum\".\n\n  Returns:\n    An `EstimatorSpec` instance.\n\n  Raises:\n    ValueError: If both `linear_feature_columns` and `dnn_features_columns`\n      are empty at the same time, or `input_layer_partitioner` is missing,\n      or features has the wrong type.\n  \"\"\"\n  if not isinstance(features, dict):\n    raise ValueError('features should be a dictionary of `Tensor`s. '\n                     'Given type: {}'.format(type(features)))\n  if not linear_feature_columns and not dnn_feature_columns:\n    raise ValueError(\n        'Either linear_feature_columns or dnn_feature_columns must be defined.')\n\n  num_ps_replicas = config.num_ps_replicas if config else 0\n  input_layer_partitioner = input_layer_partitioner or (\n      tf.compat.v1.min_max_variable_partitioner(\n          max_partitions=num_ps_replicas, min_slice_size=64 << 20))\n\n  # Build DNN Logits.\n  dnn_parent_scope = 'dnn'\n\n  if not dnn_feature_columns:\n    dnn_logits = None\n  else:\n    dnn_optimizer = optimizers.get_optimizer_instance(\n        dnn_optimizer, learning_rate=_DNN_LEARNING_RATE)\n    _check_no_sync_replicas_optimizer(dnn_optimizer)\n    if not dnn_hidden_units:\n      raise ValueError(\n          'dnn_hidden_units must be defined when dnn_feature_columns is '\n          'specified.')\n    dnn_partitioner = (\n        tf.compat.v1.min_max_variable_partitioner(\n            max_partitions=num_ps_replicas))\n    with tf.compat.v1.variable_scope(\n        dnn_parent_scope,\n        values=tuple(six.itervalues(features)),\n        partitioner=dnn_partitioner) as scope:\n      dnn_absolute_scope = scope.name\n      dnn_logit_fn = dnn.dnn_logit_fn_builder(\n          units=head.logits_dimension,\n          hidden_units=dnn_hidden_units,\n          feature_columns=dnn_feature_columns,\n          activation_fn=dnn_activation_fn,\n          dropout=dnn_dropout,\n          batch_norm=batch_norm,\n          
input_layer_partitioner=input_layer_partitioner)\n      dnn_logits = dnn_logit_fn(features=features, mode=mode)\n\n  linear_parent_scope = 'linear'\n\n  if not linear_feature_columns:\n    linear_logits = None\n  else:\n    linear_optimizer = optimizers.get_optimizer_instance(\n        linear_optimizer,\n        learning_rate=_linear_learning_rate(len(linear_feature_columns)))\n    _check_no_sync_replicas_optimizer(linear_optimizer)\n    with tf.compat.v1.variable_scope(\n        linear_parent_scope,\n        values=tuple(six.itervalues(features)),\n        partitioner=input_layer_partitioner) as scope:\n      linear_absolute_scope = scope.name\n      logit_fn = linear.linear_logit_fn_builder(\n          units=head.logits_dimension,\n          feature_columns=linear_feature_columns,\n          sparse_combiner=linear_sparse_combiner)\n      linear_logits = logit_fn(features=features)\n      _add_layer_summary(linear_logits, scope.name)\n\n  # Combine logits and build full model.\n  if dnn_logits is not None and linear_logits is not None:\n    logits = dnn_logits + linear_logits\n  elif dnn_logits is not None:\n    logits = dnn_logits\n  else:\n    logits = linear_logits\n\n  def _train_op_fn(loss):\n    \"\"\"Returns the op to optimize the loss.\"\"\"\n    train_ops = []\n    global_step = tf.compat.v1.train.get_global_step()\n    if dnn_logits is not None:\n      train_ops.append(\n          dnn_optimizer.minimize(\n              loss,\n              var_list=tf.compat.v1.get_collection(\n                  tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES,\n                  scope=dnn_absolute_scope)))\n    if linear_logits is not None:\n      train_ops.append(\n          linear_optimizer.minimize(\n              loss,\n              var_list=tf.compat.v1.get_collection(\n                  tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES,\n                  scope=linear_absolute_scope)))\n\n    train_op = tf.group(*train_ops)\n    with tf.control_dependencies([train_op]):\n     
 return tf.compat.v1.assign_add(global_step, 1).op\n\n  return head.create_estimator_spec(\n      features=features,\n      mode=mode,\n      labels=labels,\n      train_op_fn=_train_op_fn,\n      logits=logits)\n\n\n@estimator_export('estimator.DNNLinearCombinedClassifier', v1=[])\nclass DNNLinearCombinedClassifierV2(estimator.EstimatorV2):\n  \"\"\"An estimator for TensorFlow Linear and DNN joined classification models.\n\n  Note: This estimator is also known as wide-n-deep.\n\n  Example:\n\n  ```python\n  numeric_feature = numeric_column(...)\n  categorical_column_a = categorical_column_with_hash_bucket(...)\n  categorical_column_b = categorical_column_with_hash_bucket(...)\n\n  categorical_feature_a_x_categorical_feature_b = crossed_column(...)\n  categorical_feature_a_emb = embedding_column(\n      categorical_column=categorical_feature_a, ...)\n  categorical_feature_b_emb = embedding_column(\n      categorical_id_column=categorical_feature_b, ...)\n\n  estimator = tf.estimator.DNNLinearCombinedClassifier(\n      # wide settings\n      linear_feature_columns=[categorical_feature_a_x_categorical_feature_b],\n      linear_optimizer=tf_keras.optimizers.Ftrl(...),\n      # deep settings\n      dnn_feature_columns=[\n          categorical_feature_a_emb, categorical_feature_b_emb,\n          numeric_feature],\n      dnn_hidden_units=[1000, 500, 100],\n      dnn_optimizer=tf_keras.optimizers.Adagrad(...),\n      # warm-start settings\n      warm_start_from=\"/path/to/checkpoint/dir\")\n\n  # To apply L1 and L2 regularization, you can set dnn_optimizer to:\n  tf.compat.v1.train.ProximalAdagradOptimizer(\n      learning_rate=0.1,\n      l1_regularization_strength=0.001,\n      l2_regularization_strength=0.001)\n  # To apply learning rate decay, you can set dnn_optimizer to a callable:\n  lambda: tf_keras.optimizers.Adam(\n      learning_rate=tf.compat.v1.train.exponential_decay(\n          learning_rate=0.1,\n          
global_step=tf.compat.v1.train.get_global_step(),\n          decay_steps=10000,\n          decay_rate=0.96)\n  # It is the same for linear_optimizer.\n\n  # Input builders\n  def input_fn_train:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n  def input_fn_eval:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n  def input_fn_predict:\n    # Returns tf.data.Dataset of (x, None) tuple.\n    pass\n  estimator.train(input_fn=input_fn_train, steps=100)\n  metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)\n  predictions = estimator.predict(input_fn=input_fn_predict)\n  ```\n\n  Input of `train` and `evaluate` should have following features,\n  otherwise there will be a `KeyError`:\n\n  * for each `column` in `dnn_feature_columns` + `linear_feature_columns`:\n    - if `column` is a `CategoricalColumn`, a feature with `key=column.name`\n      whose `value` is a `SparseTensor`.\n    - if `column` is a `WeightedCategoricalColumn`, two features: the first\n      with `key` the id column name, the second with `key` the weight column\n      name. Both features' `value` must be a `SparseTensor`.\n    - if `column` is a `DenseColumn`, a feature with `key=column.name`\n      whose `value` is a `Tensor`.\n\n  Loss is calculated by using softmax cross entropy.\n\n  @compatibility(eager)\n  Estimators can be used while eager execution is enabled. Note that `input_fn`\n  and all hooks are executed inside a graph context, so they have to be written\n  to be compatible with graph mode. 
Note that `input_fn` code using `tf.data`\n  generally works in both graph and eager modes.\n  @end_compatibility\n  \"\"\"\n\n  def __init__(self,\n               model_dir=None,\n               linear_feature_columns=None,\n               linear_optimizer='Ftrl',\n               dnn_feature_columns=None,\n               dnn_optimizer='Adagrad',\n               dnn_hidden_units=None,\n               dnn_activation_fn=tf.nn.relu,\n               dnn_dropout=None,\n               n_classes=2,\n               weight_column=None,\n               label_vocabulary=None,\n               config=None,\n               warm_start_from=None,\n               loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,\n               batch_norm=False,\n               linear_sparse_combiner='sum'):\n    \"\"\"Initializes a DNNLinearCombinedClassifier instance.\n\n    Args:\n      model_dir: Directory to save model parameters, graph and etc. This can\n        also be used to load checkpoints from the directory into a estimator to\n        continue training a previously saved model.\n      linear_feature_columns: An iterable containing all the feature columns\n        used by linear part of the model. All items in the set must be instances\n        of classes derived from `FeatureColumn`.\n      linear_optimizer: An instance of `tf_keras.optimizers.*` used to apply\n        gradients to the linear part of the model. Can also be a string (one of\n        'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to\n        FTRL optimizer.\n      dnn_feature_columns: An iterable containing all the feature columns used\n        by deep part of the model. All items in the set must be instances of\n        classes derived from `FeatureColumn`.\n      dnn_optimizer: An instance of `tf_keras.optimizers.*` used to apply\n        gradients to the deep part of the model. Can also be a string (one of\n        'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. 
Defaults to\n        Adagrad optimizer.\n      dnn_hidden_units: List of hidden units per layer. All layers are fully\n        connected.\n      dnn_activation_fn: Activation function applied to each layer. If None,\n        will use `tf.nn.relu`.\n      dnn_dropout: When not None, the probability we will drop out a given\n        coordinate.\n      n_classes: Number of label classes. Defaults to 2, namely binary\n        classification. Must be > 1.\n      weight_column: A string or a `_NumericColumn` created by\n        `tf.feature_column.numeric_column` defining feature column representing\n        weights. It is used to down weight or boost examples during training. It\n        will be multiplied by the loss of the example. If it is a string, it is\n        used as a key to fetch weight tensor from the `features`. If it is a\n        `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then\n        weight_column.normalizer_fn is applied on it to get weight tensor.\n      label_vocabulary: A list of strings represents possible label values. If\n        given, labels must be string type and have any value in\n        `label_vocabulary`. If it is not given, that means labels are already\n        encoded as integer or float within [0, 1] for `n_classes=2` and encoded\n        as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also\n        there will be errors if vocabulary is not provided and labels are\n        string.\n      config: RunConfig object to configure the runtime settings.\n      warm_start_from: A string filepath to a checkpoint to warm-start from, or\n        a `WarmStartSettings` object to fully configure warm-starting.  If the\n        string filepath is provided instead of a `WarmStartSettings`, then all\n        weights are warm-started, and it is assumed that vocabularies and Tensor\n        names are unchanged.\n      loss_reduction: One of `tf.losses.Reduction` except `NONE`. 
Describes how\n        to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.\n      batch_norm: Whether to use batch normalization after each hidden layer.\n      linear_sparse_combiner: A string specifying how to reduce the linear model\n        if a categorical column is multivalent.  One of \"mean\", \"sqrtn\", and\n        \"sum\" -- these are effectively different ways to do example-level\n        normalization, which can be useful for bag-of-words features.  For more\n        details, see `tf.feature_column.linear_model`.\n\n    Raises:\n      ValueError: If both linear_feature_columns and dnn_features_columns are\n        empty at the same time.\n    \"\"\"\n    self._feature_columns = _validate_feature_columns(\n        linear_feature_columns=linear_feature_columns,\n        dnn_feature_columns=dnn_feature_columns)\n\n    head = head_utils.binary_or_multi_class_head(\n        n_classes,\n        weight_column=weight_column,\n        label_vocabulary=label_vocabulary,\n        loss_reduction=loss_reduction)\n    estimator._canned_estimator_api_gauge.get_cell('Classifier').set(  # pylint: disable=protected-access\n        'DNNLinearCombined')\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"Call the _dnn_linear_combined_model_fn.\"\"\"\n      return _dnn_linear_combined_model_fn_v2(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          linear_feature_columns=linear_feature_columns,\n          linear_optimizer=linear_optimizer,\n          dnn_feature_columns=dnn_feature_columns,\n          dnn_optimizer=dnn_optimizer,\n          dnn_hidden_units=dnn_hidden_units,\n          dnn_activation_fn=dnn_activation_fn,\n          dnn_dropout=dnn_dropout,\n          config=config,\n          batch_norm=batch_norm,\n          linear_sparse_combiner=linear_sparse_combiner,\n          loss_reduction=loss_reduction)\n\n    super(DNNLinearCombinedClassifierV2, self).__init__(\n        
model_fn=_model_fn,\n        model_dir=model_dir,\n        config=config,\n        warm_start_from=warm_start_from)\n\n\n@estimator_export(v1=['estimator.DNNLinearCombinedClassifier'])  # pylint: disable=missing-docstring\nclass DNNLinearCombinedClassifier(estimator.Estimator):\n  __doc__ = DNNLinearCombinedClassifierV2.__doc__.replace(\n      'SUM_OVER_BATCH_SIZE', 'SUM')\n\n  def __init__(self,\n               model_dir=None,\n               linear_feature_columns=None,\n               linear_optimizer='Ftrl',\n               dnn_feature_columns=None,\n               dnn_optimizer='Adagrad',\n               dnn_hidden_units=None,\n               dnn_activation_fn=tf.nn.relu,\n               dnn_dropout=None,\n               n_classes=2,\n               weight_column=None,\n               label_vocabulary=None,\n               input_layer_partitioner=None,\n               config=None,\n               warm_start_from=None,\n               loss_reduction=tf.compat.v1.losses.Reduction.SUM,\n               batch_norm=False,\n               linear_sparse_combiner='sum'):\n    self._feature_columns = _validate_feature_columns(\n        linear_feature_columns=linear_feature_columns,\n        dnn_feature_columns=dnn_feature_columns)\n\n    head = head_lib._binary_logistic_or_multi_class_head(  # pylint: disable=protected-access\n        n_classes, weight_column, label_vocabulary, loss_reduction)\n    estimator._canned_estimator_api_gauge.get_cell('Classifier').set(\n        'DNNLinearCombined')  # pylint: disable=protected-access\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"Call the _dnn_linear_combined_model_fn.\"\"\"\n      return _dnn_linear_combined_model_fn(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          linear_feature_columns=linear_feature_columns,\n          linear_optimizer=linear_optimizer,\n          dnn_feature_columns=dnn_feature_columns,\n          
dnn_optimizer=dnn_optimizer,\n          dnn_hidden_units=dnn_hidden_units,\n          dnn_activation_fn=dnn_activation_fn,\n          dnn_dropout=dnn_dropout,\n          input_layer_partitioner=input_layer_partitioner,\n          config=config,\n          batch_norm=batch_norm,\n          linear_sparse_combiner=linear_sparse_combiner)\n\n    super(DNNLinearCombinedClassifier, self).__init__(\n        model_fn=_model_fn,\n        model_dir=model_dir,\n        config=config,\n        warm_start_from=warm_start_from)\n\n\ndef _init_dnn_linear_combined_estimator(head, linear_feature_columns,\n                                        linear_optimizer, dnn_feature_columns,\n                                        dnn_optimizer, dnn_hidden_units,\n                                        dnn_activation_fn, dnn_dropout,\n                                        input_layer_partitioner,\n                                        linear_sparse_combiner):\n  \"\"\"Helper function for the initialization of DNNLinearCombinedEstimator.\"\"\"\n  linear_feature_columns = linear_feature_columns or []\n  dnn_feature_columns = dnn_feature_columns or []\n  feature_columns = (list(linear_feature_columns) + list(dnn_feature_columns))\n  if not feature_columns:\n    raise ValueError('Either linear_feature_columns or dnn_feature_columns '\n                     'must be defined.')\n\n  def _model_fn(features, labels, mode, config):\n    \"\"\"Call the _dnn_linear_combined_model_fn.\"\"\"\n    return _dnn_linear_combined_model_fn(\n        features=features,\n        labels=labels,\n        mode=mode,\n        head=head,\n        linear_feature_columns=linear_feature_columns,\n        linear_optimizer=linear_optimizer,\n        dnn_feature_columns=dnn_feature_columns,\n        dnn_optimizer=dnn_optimizer,\n        dnn_hidden_units=dnn_hidden_units,\n        dnn_activation_fn=dnn_activation_fn,\n        dnn_dropout=dnn_dropout,\n        input_layer_partitioner=input_layer_partitioner,\n        
config=config,\n        linear_sparse_combiner=linear_sparse_combiner)\n\n  return feature_columns, _model_fn\n\n\n@estimator_export('estimator.DNNLinearCombinedEstimator', v1=[])\nclass DNNLinearCombinedEstimatorV2(estimator.EstimatorV2):\n  \"\"\"An estimator for TensorFlow Linear and DNN joined models with custom head.\n\n  Note: This estimator is also known as wide-n-deep.\n\n  Example:\n\n  ```python\n  numeric_feature = numeric_column(...)\n  categorical_column_a = categorical_column_with_hash_bucket(...)\n  categorical_column_b = categorical_column_with_hash_bucket(...)\n\n  categorical_feature_a_x_categorical_feature_b = crossed_column(...)\n  categorical_feature_a_emb = embedding_column(\n      categorical_column=categorical_feature_a, ...)\n  categorical_feature_b_emb = embedding_column(\n      categorical_column=categorical_feature_b, ...)\n\n  estimator = tf.estimator.DNNLinearCombinedEstimator(\n      head=tf.estimator.MultiLabelHead(n_classes=3),\n      # wide settings\n      linear_feature_columns=[categorical_feature_a_x_categorical_feature_b],\n      linear_optimizer=tf_keras.optimizers.Ftrl(...),\n      # deep settings\n      dnn_feature_columns=[\n          categorical_feature_a_emb, categorical_feature_b_emb,\n          numeric_feature],\n      dnn_hidden_units=[1000, 500, 100],\n      dnn_optimizer=tf_keras.optimizers.Adagrad(...))\n\n  # To apply L1 and L2 regularization, you can set dnn_optimizer to:\n  tf.compat.v1.train.ProximalAdagradOptimizer(\n      learning_rate=0.1,\n      l1_regularization_strength=0.001,\n      l2_regularization_strength=0.001)\n  # To apply learning rate decay, you can set dnn_optimizer to a callable:\n  lambda: tf_keras.optimizers.Adam(\n      learning_rate=tf.compat.v1.train.exponential_decay(\n          learning_rate=0.1,\n          global_step=tf.compat.v1.train.get_global_step(),\n          decay_steps=10000,\n          decay_rate=0.96)\n  # It is the same for linear_optimizer.\n\n  # Input builders\n  def 
input_fn_train:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n  def input_fn_eval:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n  def input_fn_predict:\n    # Returns tf.data.Dataset of (x, None) tuple.\n    pass\n  estimator.train(input_fn=input_fn_train, steps=100)\n  metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)\n  predictions = estimator.predict(input_fn=input_fn_predict)\n  ```\n\n  Input of `train` and `evaluate` should have following features,\n  otherwise there will be a `KeyError`:\n\n  * for each `column` in `dnn_feature_columns` + `linear_feature_columns`:\n    - if `column` is a `CategoricalColumn`, a feature with `key=column.name`\n      whose `value` is a `SparseTensor`.\n    - if `column` is a `WeightedCategoricalColumn`, two features: the first\n      with `key` the id column name, the second with `key` the weight column\n      name. Both features' `value` must be a `SparseTensor`.\n    - if `column` is a `DenseColumn`, a feature with `key=column.name`\n      whose `value` is a `Tensor`.\n\n  Loss is calculated by using mean squared error.\n\n  @compatibility(eager)\n  Estimators can be used while eager execution is enabled. Note that `input_fn`\n  and all hooks are executed inside a graph context, so they have to be written\n  to be compatible with graph mode. 
Note that `input_fn` code using `tf.data`\n  generally works in both graph and eager modes.\n  @end_compatibility\n  \"\"\"\n\n  def __init__(self,\n               head,\n               model_dir=None,\n               linear_feature_columns=None,\n               linear_optimizer='Ftrl',\n               dnn_feature_columns=None,\n               dnn_optimizer='Adagrad',\n               dnn_hidden_units=None,\n               dnn_activation_fn=tf.nn.relu,\n               dnn_dropout=None,\n               config=None,\n               batch_norm=False,\n               linear_sparse_combiner='sum'):\n    \"\"\"Initializes a DNNLinearCombinedEstimator instance.\n\n    Args:\n      head: A `Head` instance constructed with a method such as\n        `tf.estimator.MultiLabelHead`.\n      model_dir: Directory to save model parameters, graph and etc. This can\n        also be used to load checkpoints from the directory into an estimator to\n        continue training a previously saved model.\n      linear_feature_columns: An iterable containing all the feature columns\n        used by linear part of the model. All items in the set must be instances\n        of classes derived from `FeatureColumn`.\n      linear_optimizer: An instance of `tf_keras.optimizers.*` used to apply\n        gradients to the linear part of the model. Can also be a string (one of\n        'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to\n        FTRL optimizer.\n      dnn_feature_columns: An iterable containing all the feature columns used\n        by deep part of the model. All items in the set must be instances of\n        classes derived from `FeatureColumn`.\n      dnn_optimizer: An instance of `tf_keras.optimizers.*` used to apply\n        gradients to the deep part of the model. Can also be a string (one of\n        'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to\n        Adagrad optimizer.\n      dnn_hidden_units: List of hidden units per layer. 
All layers are fully\n        connected.\n      dnn_activation_fn: Activation function applied to each layer. If None,\n        will use `tf.nn.relu`.\n      dnn_dropout: When not None, the probability we will drop out a given\n        coordinate.\n      config: RunConfig object to configure the runtime settings.\n      batch_norm: Whether to use batch normalization after each hidden layer.\n      linear_sparse_combiner: A string specifying how to reduce the linear model\n        if a categorical column is multivalent.  One of \"mean\", \"sqrtn\", and\n        \"sum\" -- these are effectively different ways to do example-level\n        normalization, which can be useful for bag-of-words features.  For more\n        details, see `tf.feature_column.linear_model`.\n\n    Raises:\n      ValueError: If both linear_feature_columns and dnn_features_columns are\n        empty at the same time.\n    \"\"\"\n    self._feature_columns = _validate_feature_columns(\n        linear_feature_columns=linear_feature_columns,\n        dnn_feature_columns=dnn_feature_columns)\n    estimator._canned_estimator_api_gauge.get_cell('Estimator').set(\n        'DNNLinearCombined')  # pylint: disable=protected-access\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"Call the _dnn_linear_combined_model_fn.\"\"\"\n      return _dnn_linear_combined_model_fn_v2(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          linear_feature_columns=linear_feature_columns,\n          linear_optimizer=linear_optimizer,\n          dnn_feature_columns=dnn_feature_columns,\n          dnn_optimizer=dnn_optimizer,\n          dnn_hidden_units=dnn_hidden_units,\n          dnn_activation_fn=dnn_activation_fn,\n          dnn_dropout=dnn_dropout,\n          config=config,\n          batch_norm=batch_norm,\n          linear_sparse_combiner=linear_sparse_combiner)\n\n    super(DNNLinearCombinedEstimatorV2, self).__init__(\n        
model_fn=_model_fn, model_dir=model_dir, config=config)\n\n\n@estimator_export(v1=['estimator.DNNLinearCombinedEstimator'])  # pylint: disable=missing-docstring\nclass DNNLinearCombinedEstimator(estimator.Estimator):\n  __doc__ = DNNLinearCombinedEstimatorV2.__doc__\n\n  def __init__(self,\n               head,\n               model_dir=None,\n               linear_feature_columns=None,\n               linear_optimizer='Ftrl',\n               dnn_feature_columns=None,\n               dnn_optimizer='Adagrad',\n               dnn_hidden_units=None,\n               dnn_activation_fn=tf.nn.relu,\n               dnn_dropout=None,\n               input_layer_partitioner=None,\n               config=None,\n               batch_norm=False,\n               linear_sparse_combiner='sum'):\n    self._feature_columns = _validate_feature_columns(\n        linear_feature_columns=linear_feature_columns,\n        dnn_feature_columns=dnn_feature_columns)\n    estimator._canned_estimator_api_gauge.get_cell('Estimator').set(\n        'DNNLinearCombined')  # pylint: disable=protected-access\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"Call the _dnn_linear_combined_model_fn.\"\"\"\n      return _dnn_linear_combined_model_fn(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          linear_feature_columns=linear_feature_columns,\n          linear_optimizer=linear_optimizer,\n          dnn_feature_columns=dnn_feature_columns,\n          dnn_optimizer=dnn_optimizer,\n          dnn_hidden_units=dnn_hidden_units,\n          dnn_activation_fn=dnn_activation_fn,\n          dnn_dropout=dnn_dropout,\n          input_layer_partitioner=input_layer_partitioner,\n          config=config,\n          batch_norm=batch_norm,\n          linear_sparse_combiner=linear_sparse_combiner)\n\n    super(DNNLinearCombinedEstimator, self).__init__(\n        model_fn=_model_fn, model_dir=model_dir, 
config=config)\n\n\n@estimator_export('estimator.DNNLinearCombinedRegressor', v1=[])\nclass DNNLinearCombinedRegressorV2(estimator.EstimatorV2):\n  \"\"\"An estimator for TensorFlow Linear and DNN joined models for regression.\n\n  Note: This estimator is also known as wide-n-deep.\n\n  Example:\n\n  ```python\n  numeric_feature = numeric_column(...)\n  categorical_column_a = categorical_column_with_hash_bucket(...)\n  categorical_column_b = categorical_column_with_hash_bucket(...)\n\n  categorical_feature_a_x_categorical_feature_b = crossed_column(...)\n  categorical_feature_a_emb = embedding_column(\n      categorical_column=categorical_feature_a, ...)\n  categorical_feature_b_emb = embedding_column(\n      categorical_column=categorical_feature_b, ...)\n\n  estimator = tf.estimator.DNNLinearCombinedRegressor(\n      # wide settings\n      linear_feature_columns=[categorical_feature_a_x_categorical_feature_b],\n      linear_optimizer=tf_keras.optimizers.Ftrl(...),\n      # deep settings\n      dnn_feature_columns=[\n          categorical_feature_a_emb, categorical_feature_b_emb,\n          numeric_feature],\n      dnn_hidden_units=[1000, 500, 100],\n      dnn_optimizer=tf_keras.optimizers.Adagrad(...),\n      # warm-start settings\n      warm_start_from=\"/path/to/checkpoint/dir\")\n\n  # To apply L1 and L2 regularization, you can set dnn_optimizer to:\n  tf.compat.v1.train.ProximalAdagradOptimizer(\n      learning_rate=0.1,\n      l1_regularization_strength=0.001,\n      l2_regularization_strength=0.001)\n  # To apply learning rate decay, you can set dnn_optimizer to a callable:\n  lambda: tf_keras.optimizers.Adam(\n      learning_rate=tf.compat.v1.train.exponential_decay(\n          learning_rate=0.1,\n          global_step=tf.compat.v1.train.get_global_step(),\n          decay_steps=10000,\n          decay_rate=0.96)\n  # It is the same for linear_optimizer.\n\n  # Input builders\n  def input_fn_train:\n    # Returns tf.data.Dataset of (x, y) tuple where y 
represents label's class\n    # index.\n    pass\n  def input_fn_eval:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n  def input_fn_predict:\n    # Returns tf.data.Dataset of (x, None) tuple.\n    pass\n  estimator.train(input_fn=input_fn_train, steps=100)\n  metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)\n  predictions = estimator.predict(input_fn=input_fn_predict)\n  ```\n\n  Input of `train` and `evaluate` should have following features,\n  otherwise there will be a `KeyError`:\n\n  * for each `column` in `dnn_feature_columns` + `linear_feature_columns`:\n    - if `column` is a `CategoricalColumn`, a feature with `key=column.name`\n      whose `value` is a `SparseTensor`.\n    - if `column` is a `WeightedCategoricalColumn`, two features: the first\n      with `key` the id column name, the second with `key` the weight column\n      name. Both features' `value` must be a `SparseTensor`.\n    - if `column` is a `DenseColumn`, a feature with `key=column.name`\n      whose `value` is a `Tensor`.\n\n  Loss is calculated by using mean squared error.\n\n  @compatibility(eager)\n  Estimators can be used while eager execution is enabled. Note that `input_fn`\n  and all hooks are executed inside a graph context, so they have to be written\n  to be compatible with graph mode. 
Note that `input_fn` code using `tf.data`\n  generally works in both graph and eager modes.\n  @end_compatibility\n  \"\"\"\n\n  def __init__(self,\n               model_dir=None,\n               linear_feature_columns=None,\n               linear_optimizer='Ftrl',\n               dnn_feature_columns=None,\n               dnn_optimizer='Adagrad',\n               dnn_hidden_units=None,\n               dnn_activation_fn=tf.nn.relu,\n               dnn_dropout=None,\n               label_dimension=1,\n               weight_column=None,\n               config=None,\n               warm_start_from=None,\n               loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,\n               batch_norm=False,\n               linear_sparse_combiner='sum'):\n    \"\"\"Initializes a DNNLinearCombinedRegressor instance.\n\n    Args:\n      model_dir: Directory to save model parameters, graph and etc. This can\n        also be used to load checkpoints from the directory into an estimator to\n        continue training a previously saved model.\n      linear_feature_columns: An iterable containing all the feature columns\n        used by linear part of the model. All items in the set must be instances\n        of classes derived from `FeatureColumn`.\n      linear_optimizer: An instance of `tf_keras.optimizers.*` used to apply\n        gradients to the linear part of the model. Can also be a string (one of\n        'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. Defaults to\n        FTRL optimizer.\n      dnn_feature_columns: An iterable containing all the feature columns used\n        by deep part of the model. All items in the set must be instances of\n        classes derived from `FeatureColumn`.\n      dnn_optimizer: An instance of `tf_keras.optimizers.*` used to apply\n        gradients to the deep part of the model. Can also be a string (one of\n        'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or callable. 
Defaults to\n        Adagrad optimizer.\n      dnn_hidden_units: List of hidden units per layer. All layers are fully\n        connected.\n      dnn_activation_fn: Activation function applied to each layer. If None,\n        will use `tf.nn.relu`.\n      dnn_dropout: When not None, the probability we will drop out a given\n        coordinate.\n      label_dimension: Number of regression targets per example. This is the\n        size of the last dimension of the labels and logits `Tensor` objects\n        (typically, these have shape `[batch_size, label_dimension]`).\n      weight_column: A string or a `NumericColumn` created by\n        `tf.feature_column.numeric_column` defining feature column representing\n        weights. It is used to down weight or boost examples during training. It\n        will be multiplied by the loss of the example. If it is a string, it is\n        used as a key to fetch weight tensor from the `features`. If it is a\n        `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then\n        weight_column.normalizer_fn is applied on it to get weight tensor.\n      config: RunConfig object to configure the runtime settings.\n      warm_start_from: A string filepath to a checkpoint to warm-start from, or\n        a `WarmStartSettings` object to fully configure warm-starting.  If the\n        string filepath is provided instead of a `WarmStartSettings`, then all\n        weights are warm-started, and it is assumed that vocabularies and Tensor\n        names are unchanged.\n      loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how\n        to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.\n      batch_norm: Whether to use batch normalization after each hidden layer.\n      linear_sparse_combiner: A string specifying how to reduce the linear model\n        if a categorical column is multivalent.  
One of \"mean\", \"sqrtn\", and\n        \"sum\" -- these are effectively different ways to do example-level\n        normalization, which can be useful for bag-of-words features.  For more\n        details, see `tf.feature_column.linear_model`.\n\n    Raises:\n      ValueError: If both linear_feature_columns and dnn_features_columns are\n        empty at the same time.\n    \"\"\"\n    self._feature_columns = _validate_feature_columns(\n        linear_feature_columns=linear_feature_columns,\n        dnn_feature_columns=dnn_feature_columns)\n\n    head = regression_head.RegressionHead(\n        label_dimension=label_dimension,\n        weight_column=weight_column,\n        loss_reduction=loss_reduction)\n    estimator._canned_estimator_api_gauge.get_cell('Regressor').set(\n        'DNNLinearCombined')  # pylint: disable=protected-access\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"Call the _dnn_linear_combined_model_fn.\"\"\"\n      return _dnn_linear_combined_model_fn_v2(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          linear_feature_columns=linear_feature_columns,\n          linear_optimizer=linear_optimizer,\n          dnn_feature_columns=dnn_feature_columns,\n          dnn_optimizer=dnn_optimizer,\n          dnn_hidden_units=dnn_hidden_units,\n          dnn_activation_fn=dnn_activation_fn,\n          dnn_dropout=dnn_dropout,\n          config=config,\n          batch_norm=batch_norm,\n          linear_sparse_combiner=linear_sparse_combiner)\n\n    super(DNNLinearCombinedRegressorV2, self).__init__(\n        model_fn=_model_fn,\n        model_dir=model_dir,\n        config=config,\n        warm_start_from=warm_start_from)\n\n\n@estimator_export(v1=['estimator.DNNLinearCombinedRegressor'])  # pylint: disable=missing-docstring\nclass DNNLinearCombinedRegressor(estimator.Estimator):\n  __doc__ = DNNLinearCombinedRegressorV2.__doc__.replace(\n      'SUM_OVER_BATCH_SIZE', 'SUM')\n\n  
def __init__(self,\n               model_dir=None,\n               linear_feature_columns=None,\n               linear_optimizer='Ftrl',\n               dnn_feature_columns=None,\n               dnn_optimizer='Adagrad',\n               dnn_hidden_units=None,\n               dnn_activation_fn=tf.nn.relu,\n               dnn_dropout=None,\n               label_dimension=1,\n               weight_column=None,\n               input_layer_partitioner=None,\n               config=None,\n               warm_start_from=None,\n               loss_reduction=tf.compat.v1.losses.Reduction.SUM,\n               batch_norm=False,\n               linear_sparse_combiner='sum'):\n    self._feature_columns = _validate_feature_columns(\n        linear_feature_columns=linear_feature_columns,\n        dnn_feature_columns=dnn_feature_columns)\n    estimator._canned_estimator_api_gauge.get_cell('Regressor').set(\n        'DNNLinearCombined')  # pylint: disable=protected-access\n\n    head = head_lib._regression_head(  # pylint: disable=protected-access\n        label_dimension=label_dimension,\n        weight_column=weight_column,\n        loss_reduction=loss_reduction)\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"Call the _dnn_linear_combined_model_fn.\"\"\"\n      return _dnn_linear_combined_model_fn(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          linear_feature_columns=linear_feature_columns,\n          linear_optimizer=linear_optimizer,\n          dnn_feature_columns=dnn_feature_columns,\n          dnn_optimizer=dnn_optimizer,\n          dnn_hidden_units=dnn_hidden_units,\n          dnn_activation_fn=dnn_activation_fn,\n          dnn_dropout=dnn_dropout,\n          input_layer_partitioner=input_layer_partitioner,\n          config=config,\n          batch_norm=batch_norm,\n          linear_sparse_combiner=linear_sparse_combiner)\n\n    super(DNNLinearCombinedRegressor, self).__init__(\n        
model_fn=_model_fn,\n        model_dir=model_dir,\n        config=config,\n        warm_start_from=warm_start_from)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/dnn_linear_combined_estimator_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for DNNLinearCombinedEstimatorV2.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport shutil\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import dnn_linear_combined\nfrom tensorflow_estimator.python.estimator.canned import dnn_testing_utils\nfrom tensorflow_estimator.python.estimator.canned import linear_testing_utils\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.head import regression_head\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\n\n\ndef _dnn_only_estimator_fn(hidden_units,\n                           feature_columns,\n                           model_dir=None,\n                           label_dimension=1,\n                           weight_column=None,\n                           optimizer='Adagrad',\n                           activation_fn=tf.nn.relu,\n                           dropout=None,\n                           config=None):\n  return 
dnn_linear_combined.DNNLinearCombinedEstimatorV2(\n      head=regression_head.RegressionHead(\n          weight_column=weight_column,\n          label_dimension=label_dimension,\n          # Tests in core (from which this test inherits) test the sum loss.\n          loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE),\n      model_dir=model_dir,\n      dnn_feature_columns=feature_columns,\n      dnn_optimizer=optimizer,\n      dnn_hidden_units=hidden_units,\n      dnn_activation_fn=activation_fn,\n      dnn_dropout=dropout,\n      config=config)\n\n\nclass DNNOnlyEstimatorEvaluateTest(\n    dnn_testing_utils.BaseDNNRegressorEvaluateTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(\n        self, _dnn_only_estimator_fn)\n\n\nclass DNNOnlyEstimatorPredictTest(dnn_testing_utils.BaseDNNRegressorPredictTest,\n                                  tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(\n        self, _dnn_only_estimator_fn)\n\n\nclass DNNOnlyEstimatorTrainTest(dnn_testing_utils.BaseDNNRegressorTrainTest,\n                                tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(\n        self, _dnn_only_estimator_fn)\n\n\ndef _linear_only_estimator_fn(feature_columns,\n                              model_dir=None,\n                              label_dimension=1,\n                              weight_column=None,\n                              optimizer='Ftrl',\n                              config=None,\n                              sparse_combiner='sum'):\n  return 
dnn_linear_combined.DNNLinearCombinedEstimatorV2(\n      head=regression_head.RegressionHead(\n          weight_column=weight_column,\n          label_dimension=label_dimension,\n          # Tests in core (from which this test inherits) test the sum loss.\n          loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE),\n      model_dir=model_dir,\n      linear_feature_columns=feature_columns,\n      linear_optimizer=optimizer,\n      config=config,\n      linear_sparse_combiner=sparse_combiner)\n\n\nclass LinearOnlyEstimatorEvaluateTest(\n    linear_testing_utils.BaseLinearRegressorEvaluationTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(\n        self, _linear_only_estimator_fn)\n\n\nclass LinearOnlyEstimatorPredictTest(\n    linear_testing_utils.BaseLinearRegressorPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearRegressorPredictTest.__init__(\n        self, _linear_only_estimator_fn)\n\n\nclass LinearOnlyEstimatorTrainTest(\n    linear_testing_utils.BaseLinearRegressorTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(\n        self, _linear_only_estimator_fn)\n\n\nclass DNNLinearCombinedEstimatorIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self,\n                          train_input_fn,\n                          eval_input_fn,\n                      
    predict_input_fn,\n                          input_dimension,\n                          label_dimension,\n                          batch_size,\n                          dnn_optimizer='Adagrad',\n                          linear_optimizer='Ftrl'):\n    linear_feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    dnn_feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    feature_columns = linear_feature_columns + dnn_feature_columns\n    est = dnn_linear_combined.DNNLinearCombinedEstimatorV2(\n        head=regression_head.RegressionHead(label_dimension=label_dimension),\n        linear_feature_columns=linear_feature_columns,\n        dnn_feature_columns=dnn_feature_columns,\n        dnn_hidden_units=(2, 2),\n        model_dir=self._model_dir,\n        dnn_optimizer=dnn_optimizer,\n        linear_optimizer=linear_optimizer)\n\n    # Train\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # Evaluate\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # Predict\n    predictions = np.array([\n        x[prediction_keys.PredictionKeys.PREDICTIONS]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, label_dimension), predictions.shape)\n\n    # Export\n    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def _create_input_fn(self, label_dimension, batch_size):\n    \"\"\"Creates input_fn for integration test.\"\"\"\n    data = np.linspace(0., 2., batch_size * 
label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, batch_size=batch_size, shuffle=False)\n\n    return train_input_fn, eval_input_fn, predict_input_fn\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    train_input_fn, eval_input_fn, predict_input_fn = self._create_input_fn(\n        label_dimension, batch_size)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size)\n\n  def test_numpy_input_fn_with_optimizer_instance(self):\n    \"\"\"Tests complete flow with optimizer_v2 instance.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    train_input_fn, eval_input_fn, predict_input_fn = self._create_input_fn(\n        label_dimension, batch_size)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size,\n        dnn_optimizer=tf_keras.optimizers.legacy.Adagrad(0.01),\n        linear_optimizer=tf_keras.optimizers.legacy.Ftrl(0.01))\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/dnn_linear_combined_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for v2 version of dnn_linear_combined.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport shutil\nimport tempfile\n\nfrom absl.testing import parameterized\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.core.example import example_pb2\nfrom tensorflow.core.example import feature_pb2\nfrom tensorflow.python.feature_column import feature_column\nfrom tensorflow.python.feature_column import feature_column_v2\nfrom tensorflow_estimator.python.estimator import estimator\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import dnn_linear_combined\nfrom tensorflow_estimator.python.estimator.canned import dnn_testing_utils\nfrom tensorflow_estimator.python.estimator.canned import linear_testing_utils\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.inputs import pandas_io\n\ntry:\n  # pylint: disable=g-import-not-at-top\n  import pandas as pd\n  HAS_PANDAS = True\nexcept IOError:\n  
# Pandas writes a temporary file during import. If it fails, don't use pandas.\n  HAS_PANDAS = False\nexcept ImportError:\n  HAS_PANDAS = False\n\n\nclass DNNOnlyModelFnTest(dnn_testing_utils.BaseDNNModelFnTest,\n                         tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNModelFnTest.__init__(self, self._dnn_only_model_fn)\n\n  def _dnn_only_model_fn(self,\n                         features,\n                         labels,\n                         mode,\n                         head,\n                         hidden_units,\n                         feature_columns,\n                         optimizer='Adagrad',\n                         activation_fn=tf.nn.relu,\n                         dropout=None,\n                         config=None):\n    return dnn_linear_combined._dnn_linear_combined_model_fn_v2(\n        features=features,\n        labels=labels,\n        mode=mode,\n        head=head,\n        linear_feature_columns=[],\n        dnn_hidden_units=hidden_units,\n        dnn_feature_columns=feature_columns,\n        dnn_optimizer=optimizer,\n        dnn_activation_fn=activation_fn,\n        dnn_dropout=dropout,\n        config=config)\n\n\n# A function to mimic linear-regressor init reuse same tests.\ndef _linear_regressor_fn(feature_columns,\n                         model_dir=None,\n                         label_dimension=1,\n                         weight_column=None,\n                         optimizer='Ftrl',\n                         config=None,\n                         sparse_combiner='sum'):\n  return dnn_linear_combined.DNNLinearCombinedRegressorV2(\n      model_dir=model_dir,\n      linear_feature_columns=feature_columns,\n      linear_optimizer=optimizer,\n      label_dimension=label_dimension,\n      weight_column=weight_column,\n      config=config,\n      
linear_sparse_combiner=sparse_combiner)\n\n\nclass LinearOnlyRegressorEvaluationV2Test(\n    linear_testing_utils.BaseLinearRegressorEvaluationTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\nclass LinearOnlyRegressorPredictV2Test(\n    linear_testing_utils.BaseLinearRegressorPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearRegressorPredictTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\nclass LinearOnlyRegressorIntegrationV2Test(\n    linear_testing_utils.BaseLinearRegressorIntegrationTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\nclass LinearOnlyRegressorTrainingV2Test(\n    linear_testing_utils.BaseLinearRegressorTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\ndef _linear_classifier_fn(feature_columns,\n                          model_dir=None,\n                          n_classes=2,\n                          weight_column=None,\n                          label_vocabulary=None,\n                          optimizer='Ftrl',\n                          config=None,\n                          sparse_combiner='sum'):\n  return 
dnn_linear_combined.DNNLinearCombinedClassifierV2(\n      model_dir=model_dir,\n      linear_feature_columns=feature_columns,\n      linear_optimizer=optimizer,\n      n_classes=n_classes,\n      weight_column=weight_column,\n      label_vocabulary=label_vocabulary,\n      config=config,\n      linear_sparse_combiner=sparse_combiner)\n\n\nclass LinearOnlyClassifierTrainingV2Test(\n    linear_testing_utils.BaseLinearClassifierTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(\n        self,\n        linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\nclass LinearOnlyClassifierClassesEvaluationV2Test(\n    linear_testing_utils.BaseLinearClassifierEvaluationTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(\n        self,\n        linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\nclass LinearOnlyClassifierPredictV2Test(\n    linear_testing_utils.BaseLinearClassifierPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearClassifierPredictTest.__init__(\n        self,\n        linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\nclass LinearOnlyClassifierIntegrationV2Test(\n    linear_testing_utils.BaseLinearClassifierIntegrationTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(\n        self,\n        
linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\n@parameterized.parameters((feature_column_v2,))\nclass DNNLinearCombinedRegressorIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow_helper(self, linear_feature_columns,\n                                 dnn_feature_columns, feature_spec,\n                                 train_input_fn, eval_input_fn,\n                                 predict_input_fn, input_dimension,\n                                 label_dimension, batch_size):\n    est = dnn_linear_combined.DNNLinearCombinedRegressorV2(\n        linear_feature_columns=linear_feature_columns,\n        dnn_hidden_units=(2, 2),\n        dnn_feature_columns=dnn_feature_columns,\n        label_dimension=label_dimension,\n        model_dir=self._model_dir)\n\n    # TRAIN\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # EVALUTE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # PREDICT\n    predictions = np.array([\n        x[prediction_keys.PredictionKeys.PREDICTIONS]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, label_dimension), predictions.shape)\n\n    # EXPORT\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, label_dimension, batch_size,\n                
          fc_impl):\n    linear_feature_columns = [\n        fc_impl.numeric_column('x', shape=(input_dimension,))\n    ]\n    dnn_feature_columns = [\n        fc_impl.numeric_column('x', shape=(input_dimension,))\n    ]\n    feature_columns = linear_feature_columns + dnn_feature_columns\n    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)\n    self._test_complete_flow_helper(linear_feature_columns, dnn_feature_columns,\n                                    feature_spec, train_input_fn, eval_input_fn,\n                                    predict_input_fn, input_dimension,\n                                    label_dimension, batch_size)\n\n  def _test_complete_flow_dnn_fc_v1(self, train_input_fn, eval_input_fn,\n                                    predict_input_fn, input_dimension,\n                                    label_dimension, batch_size, fc_impl):\n    del fc_impl\n    linear_feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    dnn_feature_columns = [\n        feature_column._numeric_column('x', shape=(input_dimension,))\n    ]\n    feature_columns = linear_feature_columns + dnn_feature_columns\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    self._test_complete_flow_helper(linear_feature_columns, dnn_feature_columns,\n                                    feature_spec, train_input_fn, eval_input_fn,\n                                    predict_input_fn, input_dimension,\n                                    label_dimension, batch_size)\n\n  def _test_complete_flow_linear_fc_v1(self, train_input_fn, eval_input_fn,\n                                       predict_input_fn, input_dimension,\n                                       label_dimension, batch_size, fc_impl):\n    del fc_impl\n    linear_feature_columns = [\n        feature_column._numeric_column('x', shape=(input_dimension,))\n    ]\n    dnn_feature_columns = [\n        
tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    feature_columns = linear_feature_columns + dnn_feature_columns\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    self._test_complete_flow_helper(linear_feature_columns, dnn_feature_columns,\n                                    feature_spec, train_input_fn, eval_input_fn,\n                                    predict_input_fn, input_dimension,\n                                    label_dimension, batch_size)\n\n  def _test_numpy_input_fn_helper(self, fc_impl, fn_to_run):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, batch_size=batch_size, shuffle=False)\n\n    fn_to_run(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size,\n        fc_impl=fc_impl)\n\n  def test_numpy_input_fn_basic(self, fc_impl):\n    self._test_numpy_input_fn_helper(fc_impl, self._test_complete_flow)\n\n  def test_numpy_input_fn_dnn_fc_v1(self, fc_impl):\n    with self.assertRaisesRegexp(\n        ValueError, r'Received a feature column from TensorFlow v1'):\n      self._test_numpy_input_fn_helper(fc_impl,\n                                       self._test_complete_flow_dnn_fc_v1)\n\n  def test_numpy_input_fn_linear_fc_v1(self, 
fc_impl):\n    with self.assertRaisesRegexp(\n        ValueError, r'Received a feature column from TensorFlow v1'):\n      self._test_numpy_input_fn_helper(fc_impl,\n                                       self._test_complete_flow_linear_fc_v1)\n\n  def _test_pandas_input_fn_helper(self, fc_impl, fn_to_run):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n    label_dimension = 1\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size, dtype=np.float32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(data)\n    train_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, shuffle=False)\n\n    fn_to_run(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size,\n        fc_impl=fc_impl)\n\n  def test_pandas_input_fn_basic(self, fc_impl):\n    self._test_pandas_input_fn_helper(fc_impl, self._test_complete_flow)\n\n  def test_pandas_input_fn_dnn_fc_v1(self, fc_impl):\n    with self.assertRaisesRegexp(\n        ValueError, r'Received a feature column from TensorFlow v1'):\n      self._test_pandas_input_fn_helper(fc_impl,\n                                        self._test_complete_flow_dnn_fc_v1)\n\n  def test_pandas_input_fn_linear_fc_v1(self, fc_impl):\n    with self.assertRaisesRegexp(\n        ValueError, r'Received a feature column from TensorFlow v1'):\n      self._test_pandas_input_fn_helper(fc_impl,\n                                        self._test_complete_flow_linear_fc_v1)\n\n  def _test_input_fn_from_parse_example_helper(self, fc_impl, fn_to_run):\n    \"\"\"Tests complete 
flow with input_fn constructed from parse_example.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n\n    serialized_examples = []\n    for datum in data:\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n                  'y':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([label_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([label_dimension], tf.dtypes.float32),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = linear_testing_utils.queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = linear_testing_utils.queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = linear_testing_utils.queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    fn_to_run(\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        
predict_input_fn=_predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size,\n        fc_impl=fc_impl)\n\n  def test_input_fn_from_parse_example_basic(self, fc_impl):\n    self._test_input_fn_from_parse_example_helper(fc_impl,\n                                                  self._test_complete_flow)\n\n  def test_input_fn_from_parse_example_dnn_fc_v1(self, fc_impl):\n    with self.assertRaisesRegexp(\n        ValueError, r'Received a feature column from TensorFlow v1'):\n      self._test_input_fn_from_parse_example_helper(\n          fc_impl, self._test_complete_flow_dnn_fc_v1)\n\n  def test_input_fn_from_parse_example_linear_fc_v1(self, fc_impl):\n    with self.assertRaisesRegexp(\n        ValueError, r'Received a feature column from TensorFlow v1'):\n      self._test_input_fn_from_parse_example_helper(\n          fc_impl, self._test_complete_flow_linear_fc_v1)\n\n\n# A function to mimic dnn-classifier init reuse same tests.\ndef _dnn_classifier_fn(hidden_units,\n                       feature_columns,\n                       model_dir=None,\n                       n_classes=2,\n                       weight_column=None,\n                       label_vocabulary=None,\n                       optimizer='Adagrad',\n                       config=None):\n  return dnn_linear_combined.DNNLinearCombinedClassifierV2(\n      model_dir=model_dir,\n      dnn_hidden_units=hidden_units,\n      dnn_feature_columns=feature_columns,\n      dnn_optimizer=optimizer,\n      n_classes=n_classes,\n      weight_column=weight_column,\n      label_vocabulary=label_vocabulary,\n      config=config)\n\n\nclass DNNOnlyClassifierEvaluateV2Test(\n    dnn_testing_utils.BaseDNNClassifierEvaluateTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(\n        
self, _dnn_classifier_fn, fc_impl=feature_column_v2)\n\n\nclass DNNOnlyClassifierPredictV2Test(\n    dnn_testing_utils.BaseDNNClassifierPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(\n        self, _dnn_classifier_fn, fc_impl=feature_column_v2)\n\n\nclass DNNOnlyClassifierTrainV2Test(dnn_testing_utils.BaseDNNClassifierTrainTest,\n                                   tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(\n        self, _dnn_classifier_fn, fc_impl=feature_column_v2)\n\n\n# A function to mimic dnn-regressor init reuse same tests.\ndef _dnn_regressor_fn(hidden_units,\n                      feature_columns,\n                      model_dir=None,\n                      label_dimension=1,\n                      weight_column=None,\n                      optimizer='Adagrad',\n                      config=None):\n  return dnn_linear_combined.DNNLinearCombinedRegressorV2(\n      model_dir=model_dir,\n      dnn_hidden_units=hidden_units,\n      dnn_feature_columns=feature_columns,\n      dnn_optimizer=optimizer,\n      label_dimension=label_dimension,\n      weight_column=weight_column,\n      config=config)\n\n\nclass DNNOnlyRegressorEvaluateV2Test(\n    dnn_testing_utils.BaseDNNRegressorEvaluateTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column_v2)\n\n\nclass DNNOnlyRegressorPredictV2Test(\n    dnn_testing_utils.BaseDNNRegressorPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: 
disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column_v2)\n\n\nclass DNNOnlyRegressorTrainV2Test(dnn_testing_utils.BaseDNNRegressorTrainTest,\n                                  tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column_v2)\n\n\n@parameterized.parameters((feature_column_v2,))\nclass DNNLinearCombinedClassifierIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _as_label(self, data_in_float):\n    return np.rint(data_in_float).astype(np.int64)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, n_classes, batch_size, fc_impl):\n    linear_feature_columns = [\n        fc_impl.numeric_column('x', shape=(input_dimension,))\n    ]\n    dnn_feature_columns = [\n        fc_impl.numeric_column('x', shape=(input_dimension,))\n    ]\n    feature_columns = linear_feature_columns + dnn_feature_columns\n    est = dnn_linear_combined.DNNLinearCombinedClassifierV2(\n        linear_feature_columns=linear_feature_columns,\n        dnn_hidden_units=(2, 2),\n        dnn_feature_columns=dnn_feature_columns,\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    # TRAIN\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # EVALUTE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # PREDICT\n    
predicted_proba = np.array([\n        x[prediction_keys.PredictionKeys.PROBABILITIES]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)\n\n    # EXPORT\n    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self, fc_impl):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    n_classes = 3\n    input_dimension = 2\n    batch_size = 10\n    data = np.linspace(\n        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)\n    x_data = data.reshape(batch_size, input_dimension)\n    y_data = self._as_label(np.reshape(data[:batch_size], (batch_size, 1)))\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': x_data},\n        y=y_data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': x_data}, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        n_classes=n_classes,\n        batch_size=batch_size,\n        fc_impl=fc_impl)\n\n  def test_pandas_input_fn(self, fc_impl):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n    input_dimension = 1\n    n_classes = 2\n    batch_size = 10\n    data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)\n 
   x = pd.DataFrame({'x': data})\n    y = pd.Series(self._as_label(data))\n    train_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        n_classes=n_classes,\n        batch_size=batch_size,\n        fc_impl=fc_impl)\n\n  def test_input_fn_from_parse_example(self, fc_impl):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    input_dimension = 2\n    n_classes = 3\n    batch_size = 10\n    data = np.linspace(\n        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, input_dimension)\n\n    serialized_examples = []\n    for datum in data:\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n                  'y':\n                      feature_pb2.Feature(\n                          int64_list=feature_pb2.Int64List(\n                              value=self._as_label(datum[:1]))),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([input_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([1], tf.dtypes.int64),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = 
linear_testing_utils.queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = linear_testing_utils.queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = linear_testing_utils.queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=input_dimension,\n        n_classes=n_classes,\n        batch_size=batch_size,\n        fc_impl=fc_impl)\n\n\n@parameterized.parameters((feature_column_v2,))\nclass DNNLinearCombinedTests(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def test_train_op_calls_both_dnn_and_linear(self, fc_impl):\n    dnn_opt = tf_keras.optimizers.legacy.SGD(1.)\n    linear_opt = tf_keras.optimizers.legacy.SGD(1.)\n    x_column = fc_impl.numeric_column('x')\n    input_fn = numpy_io.numpy_input_fn(\n        x={'x': np.array([[0.], [1.]])},\n        y=np.array([[0.], [1.]]),\n        batch_size=1,\n        shuffle=False)\n    est = dnn_linear_combined.DNNLinearCombinedClassifierV2(\n        linear_feature_columns=[x_column],\n        # verifies linear_optimizer is used only for linear part.\n        linear_optimizer=linear_opt,\n        dnn_hidden_units=(2, 2),\n        dnn_feature_columns=[x_column],\n        # verifies dnn_optimizer 
is used only for dnn part.\n        dnn_optimizer=dnn_opt,\n        model_dir=self._model_dir)\n    num_steps = 1\n    est.train(input_fn, steps=num_steps)\n    # verifies train_op fires linear minimize op\n    self.assertEqual(num_steps,\n                     est.get_variable_value(linear_opt.iterations.name))\n    # verifies train_op fires dnn optmizer\n    self.assertEqual(num_steps, est.get_variable_value(dnn_opt.iterations.name))\n\n  def test_dnn_and_linear_logits_are_added(self, fc_impl):\n    with tf.Graph().as_default():\n      tf.Variable([[1.0]], name='linear/linear_model/x/weights')\n      tf.Variable([2.0], name='linear/linear_model/bias_weights')\n      tf.Variable([[3.0]], name='dnn/hiddenlayer_0/kernel')\n      tf.Variable([4.0], name='dnn/hiddenlayer_0/bias')\n      tf.Variable([[5.0]], name='dnn/logits/kernel')\n      tf.Variable([6.0], name='dnn/logits/bias')\n      tf.Variable(1, name='global_step', dtype=tf.dtypes.int64)\n      linear_testing_utils.save_variables_to_ckpt(self._model_dir)\n\n    x_column = fc_impl.numeric_column('x')\n    est = dnn_linear_combined.DNNLinearCombinedRegressorV2(\n        linear_feature_columns=[x_column],\n        dnn_hidden_units=[1],\n        dnn_feature_columns=[x_column],\n        model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        x={'x': np.array([[10.]])}, batch_size=1, shuffle=False)\n    # linear logits = 10*1 + 2 = 12\n    # dnn logits = (10*3 + 4)*5 + 6 = 176\n    # logits = dnn + linear = 176 + 12 = 188\n    self.assertAllClose({\n        prediction_keys.PredictionKeys.PREDICTIONS: [188.],\n    }, next(est.predict(input_fn=input_fn)))\n\n\n@parameterized.parameters((feature_column_v2,))\nclass DNNLinearCombinedWarmStartingTest(tf.test.TestCase):\n\n  def setUp(self):\n    # Create a directory to save our old checkpoint and vocabularies to.\n    self._ckpt_and_vocab_dir = tempfile.mkdtemp()\n\n    # Make a dummy input_fn.\n    def _input_fn():\n      features = {\n          
'age': [[23.], [31.]],\n          'city': [['Palo Alto'], ['Mountain View']],\n      }\n      return features, [0, 1]\n\n    self._input_fn = _input_fn\n\n  def tearDown(self):\n    # Clean up checkpoint / vocab dir.\n    tf.compat.v1.summary.FileWriterCache.clear()\n    shutil.rmtree(self._ckpt_and_vocab_dir)\n\n  def test_classifier_basic_warm_starting(self, fc_impl):\n    \"\"\"Tests correctness of DNNLinearCombinedClassifier default warm-start.\"\"\"\n    age = fc_impl.numeric_column('age')\n    city = fc_impl.embedding_column(\n        fc_impl.categorical_column_with_vocabulary_list(\n            'city', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n\n    # Create a DNNLinearCombinedClassifier and train to save a checkpoint.\n    dnn_lc_classifier = dnn_linear_combined.DNNLinearCombinedClassifierV2(\n        linear_feature_columns=[age],\n        dnn_feature_columns=[city],\n        dnn_hidden_units=[256, 128],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        linear_optimizer='SGD',\n        dnn_optimizer='SGD')\n    dnn_lc_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNLinearCombinedClassifier, warm-started from the first.\n    # Use a learning_rate = 0.0 optimizer to check values (use SGD so we don't\n    # have accumulator values that change).\n    # To avoid optimizer naming issue during warm start, when to create the\n    # optimizer instance, the dnn_optimizer needs to be created first\n    # before the linear_optimizer, since this is the order pre-defined\n    # in the model function.\n    # Create a default graph context to make sure the optimizer instance is\n    # created within Graph v1 to make it consistent with estimator Graph.\n    with tf.Graph().as_default():\n      warm_started_dnn_lc_classifier = (\n          dnn_linear_combined.DNNLinearCombinedClassifierV2(\n              linear_feature_columns=[age],\n              dnn_feature_columns=[city],\n         
     dnn_hidden_units=[256, 128],\n              n_classes=4,\n              dnn_optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n              linear_optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n              warm_start_from=dnn_lc_classifier.model_dir))\n\n    warm_started_dnn_lc_classifier.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_dnn_lc_classifier.get_variable_names():\n      if 'learning_rate' in variable_name:\n        self.assertAllClose(\n            0.0,\n            warm_started_dnn_lc_classifier.get_variable_value(variable_name))\n      else:\n        self.assertAllClose(\n            dnn_lc_classifier.get_variable_value(variable_name),\n            warm_started_dnn_lc_classifier.get_variable_value(variable_name))\n\n  def test_regressor_basic_warm_starting(self, fc_impl):\n    \"\"\"Tests correctness of DNNLinearCombinedRegressor default warm-start.\"\"\"\n    age = fc_impl.numeric_column('age')\n    city = fc_impl.embedding_column(\n        fc_impl.categorical_column_with_vocabulary_list(\n            'city', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n\n    # Create a DNNLinearCombinedRegressor and train to save a checkpoint.\n    dnn_lc_regressor = dnn_linear_combined.DNNLinearCombinedRegressorV2(\n        linear_feature_columns=[age],\n        dnn_feature_columns=[city],\n        dnn_hidden_units=[256, 128],\n        model_dir=self._ckpt_and_vocab_dir,\n        linear_optimizer='SGD',\n        dnn_optimizer='SGD')\n    dnn_lc_regressor.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNLinearCombinedRegressor, warm-started from the first.\n    # Use a learning_rate = 0.0 optimizer to check values (use SGD so we don't\n    # have accumulator values that change).\n    # To avoid optimizer naming issue during warm start, when to create the\n    # optimizer instance, the dnn_optimizer needs to be created first\n    # before the 
linear_optimizer, since this is the order pre-defined\n    # in the model function.\n    # Create a default graph context to make sure the optimizer instance is\n    # created within Graph v1 to make it consistent with estimator Graph.\n    with tf.Graph().as_default():\n      warm_started_dnn_lc_regressor = (\n          dnn_linear_combined.DNNLinearCombinedRegressorV2(\n              linear_feature_columns=[age],\n              dnn_feature_columns=[city],\n              dnn_hidden_units=[256, 128],\n              dnn_optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n              linear_optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n              warm_start_from=dnn_lc_regressor.model_dir))\n\n    warm_started_dnn_lc_regressor.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_dnn_lc_regressor.get_variable_names():\n      if 'learning_rate' in variable_name:\n        self.assertAllClose(\n            0.0,\n            warm_started_dnn_lc_regressor.get_variable_value(variable_name))\n      else:\n        self.assertAllClose(\n            dnn_lc_regressor.get_variable_value(variable_name),\n            warm_started_dnn_lc_regressor.get_variable_value(variable_name))\n\n  def test_warm_starting_selective_variables(self, fc_impl):\n    \"\"\"Tests selecting variables to warm-start.\"\"\"\n    age = fc_impl.numeric_column('age')\n    city = fc_impl.embedding_column(\n        fc_impl.categorical_column_with_vocabulary_list(\n            'city', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n\n    # Create a DNNLinearCombinedClassifier and train to save a checkpoint.\n    dnn_lc_classifier = dnn_linear_combined.DNNLinearCombinedClassifierV2(\n        linear_feature_columns=[age],\n        dnn_feature_columns=[city],\n        dnn_hidden_units=[256, 128],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        linear_optimizer='SGD',\n        dnn_optimizer='SGD')\n    
dnn_lc_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNLinearCombinedClassifier, warm-started from the first.\n    # Use a learning_rate = 0.0 optimizer to check values (use SGD so we don't\n    # have accumulator values that change).\n    warm_started_dnn_lc_classifier = (\n        dnn_linear_combined.DNNLinearCombinedClassifierV2(\n            linear_feature_columns=[age],\n            dnn_feature_columns=[city],\n            dnn_hidden_units=[256, 128],\n            n_classes=4,\n            linear_optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n            dnn_optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n            # The provided regular expression will only warm-start the deep\n            # portion of the model.\n            warm_start_from=estimator.WarmStartSettings(\n                ckpt_to_initialize_from=dnn_lc_classifier.model_dir,\n                vars_to_warm_start='.*(dnn).*')))\n\n    warm_started_dnn_lc_classifier.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_dnn_lc_classifier.get_variable_names():\n      if 'dnn' in variable_name:\n        if 'learning_rate' in variable_name:\n          self.assertAllClose(\n              0.0,\n              warm_started_dnn_lc_classifier.get_variable_value(variable_name))\n        else:\n          self.assertAllClose(\n              dnn_lc_classifier.get_variable_value(variable_name),\n              warm_started_dnn_lc_classifier.get_variable_value(variable_name))\n      elif 'linear' in variable_name:\n        linear_values = warm_started_dnn_lc_classifier.get_variable_value(\n            variable_name)\n        # Since they're not warm-started, the linear weights will be\n        # zero-initialized.\n        self.assertAllClose(np.zeros_like(linear_values), linear_values)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/dnn_test_fc_v2.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for dnn.py with feature_column_v2.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport shutil\nimport tempfile\n\nfrom unittest.mock import patch\n\nfrom absl.testing import parameterized\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.core.example import example_pb2\nfrom tensorflow.core.example import feature_pb2\nfrom tensorflow.python.feature_column import feature_column_v2\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import dnn\nfrom tensorflow_estimator.python.estimator.canned import dnn_testing_utils\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.inputs import pandas_io\n\ntry:\n  # pylint: disable=g-import-not-at-top\n  import pandas as pd\n  HAS_PANDAS = True\nexcept IOError:\n  # Pandas writes a temporary file during import. 
If it fails, don't use pandas.\n  HAS_PANDAS = False\nexcept ImportError:\n  HAS_PANDAS = False\n\n\ndef _dnn_classifier_fn(*args, **kwargs):\n  return dnn.DNNClassifierV2(*args, **kwargs)\n\n\nclass DNNModelFnV2Test(dnn_testing_utils.BaseDNNModelFnTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNModelFnTest.__init__(\n        self, dnn.dnn_model_fn_v2, fc_impl=feature_column_v2)\n\n\nclass DNNLogitFnV2Test(dnn_testing_utils.BaseDNNLogitFnTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNLogitFnTest.__init__(\n        self, dnn.dnn_logit_fn_builder_v2, fc_impl=feature_column_v2)\n\n\nclass DNNWarmStartingV2Test(dnn_testing_utils.BaseDNNWarmStartingTest,\n                            tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNWarmStartingTest.__init__(\n        self, _dnn_classifier_fn, _dnn_regressor_fn, fc_impl=feature_column_v2)\n\n\nclass DNNClassifierEvaluateV2Test(\n    dnn_testing_utils.BaseDNNClassifierEvaluateTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(\n        self, _dnn_classifier_fn, fc_impl=feature_column_v2)\n\n\nclass DNNClassifierPredictV2Test(dnn_testing_utils.BaseDNNClassifierPredictTest,\n                                 tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(\n        self, _dnn_classifier_fn, 
fc_impl=feature_column_v2)\n\n\nclass DNNClassifierTrainV2Test(dnn_testing_utils.BaseDNNClassifierTrainTest,\n                               tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(\n        self, _dnn_classifier_fn, fc_impl=feature_column_v2)\n\n\ndef _dnn_regressor_fn(*args, **kwargs):\n  return dnn.DNNRegressorV2(*args, **kwargs)\n\n\nclass DNNRegressorEvaluateV2Test(dnn_testing_utils.BaseDNNRegressorEvaluateTest,\n                                 tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column_v2)\n\n\nclass DNNRegressorPredictV2Test(dnn_testing_utils.BaseDNNRegressorPredictTest,\n                                tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column_v2)\n\n\nclass DNNRegressorTrainV2Test(dnn_testing_utils.BaseDNNRegressorTrainTest,\n                              tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column_v2)\n\n\ndef _queue_parsed_features(feature_map):\n  tensors_to_enqueue = []\n  keys = []\n  for key, tensor in six.iteritems(feature_map):\n    keys.append(key)\n    tensors_to_enqueue.append(tensor)\n  queue_dtypes = [x.dtype for x in tensors_to_enqueue]\n  input_queue = tf.queue.FIFOQueue(capacity=100, dtypes=queue_dtypes)\n  
tf.compat.v1.train.queue_runner.add_queue_runner(\n      tf.compat.v1.train.queue_runner.QueueRunner(\n          input_queue, [input_queue.enqueue(tensors_to_enqueue)]))\n  dequeued_tensors = input_queue.dequeue()\n  return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}\n\n\nclass DNNRegressorIntegrationTest(tf.test.TestCase, parameterized.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, label_dimension, batch_size):\n    feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n\n    est = dnn.DNNRegressorV2(\n        hidden_units=(2, 2),\n        feature_columns=feature_columns,\n        label_dimension=label_dimension,\n        model_dir=self._model_dir)\n\n    # TRAIN\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # EVALUATE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # PREDICT\n    predictions = np.array([\n        x[prediction_keys.PredictionKeys.PREDICTIONS]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, label_dimension), predictions.shape)\n\n    # EXPORT\n    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with 
numpy_input_fn.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size)\n\n  def test_pandas_input_fn(self):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n    label_dimension = 1\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size, dtype=np.float32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(data)\n    train_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size)\n\n  def test_input_fn_from_parse_example(self):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    data = np.linspace(0., 
2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n\n    serialized_examples = []\n    for datum in data:\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n                  'y':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([label_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([label_dimension], tf.dtypes.float32),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = _queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = _queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = _queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size)\n\n\nclass 
DNNClassifierIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _as_label(self, data_in_float):\n    return np.rint(data_in_float).astype(np.int64)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, n_classes, batch_size):\n    feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n\n    est = dnn.DNNClassifierV2(\n        hidden_units=(2, 2),\n        feature_columns=feature_columns,\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    # TRAIN\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # EVALUATE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # PREDICT\n    predicted_proba = np.array([\n        x[prediction_keys.PredictionKeys.PROBABILITIES]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)\n\n    # EXPORT\n    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    n_classes = 3\n    input_dimension = 2\n    batch_size = 10\n    data = np.linspace(\n        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)\n    x_data = data.reshape(batch_size, input_dimension)\n    y_data = 
np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': x_data},\n        y=y_data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': x_data}, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        n_classes=n_classes,\n        batch_size=batch_size)\n\n  def test_pandas_input_fn(self):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n    input_dimension = 1\n    n_classes = 3\n    batch_size = 10\n    data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(self._as_label(data))\n    train_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        n_classes=n_classes,\n        batch_size=batch_size)\n\n  def test_input_fn_from_parse_example(self):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    input_dimension = 2\n    n_classes = 3\n    batch_size = 10\n    data = np.linspace(\n        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)\n    data = 
data.reshape(batch_size, input_dimension)\n\n    serialized_examples = []\n    for datum in data:\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n                  'y':\n                      feature_pb2.Feature(\n                          int64_list=feature_pb2.Int64List(\n                              value=self._as_label(datum[:1]))),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([input_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([1], tf.dtypes.int64),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = _queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = _queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = _queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=input_dimension,\n        n_classes=n_classes,\n        batch_size=batch_size)\n\n\nclass DNNTrainingMode(tf.test.TestCase):\n  \"\"\"Tests 
that training mode propagates to feature columns correctly.\"\"\"\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n    self._label_dimension = 1\n    self._batch_size = 10\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _create_data(self):\n    data = np.linspace(\n        0., 2., self._batch_size * self._label_dimension, dtype=np.float32)\n    return data.reshape(self._batch_size, self._label_dimension)\n\n  def _get_estimator(self):\n    feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(self._label_dimension,))\n    ]\n    return dnn.DNNRegressorV2(\n        hidden_units=(2, 2),\n        feature_columns=feature_columns,\n        label_dimension=self._label_dimension,\n        model_dir=self._model_dir)\n\n  def test_train_vs_eval_mode(self):\n    data = self._create_data()\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=self._batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, batch_size=self._batch_size, shuffle=False)\n    est = self._get_estimator()\n    with patch.object(\n        tf_keras.layers.DenseFeatures, 'call',\n        return_value=data) as mock_dense_features_call:\n      est.train(train_input_fn, steps=10)\n      est.evaluate(eval_input_fn)\n    train_args, eval_args = mock_dense_features_call.call_args_list\n    # DenseFeature should have been called with training = True in train.\n    _, train_training_kwarg = train_args\n    self.assertTrue(train_training_kwarg['training'])\n    # DenseFeature should have been called with training = False in eval.\n    _, eval_training_kwarg = eval_args\n    self.assertFalse(eval_training_kwarg['training'])\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/dnn_testing_utils.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Utils to be used in testing DNN estimators.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport shutil\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.feature_column import feature_column_v2\nfrom tensorflow.python.framework import ops\nfrom tensorflow_estimator.python.estimator import estimator\nfrom tensorflow_estimator.python.estimator import model_fn\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.head import base_head\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n# pylint rules which are disabled by default for test files.\n# pylint: disable=invalid-name,protected-access,missing-docstring\n\n# Names of variables created by model.\nLEARNING_RATE_NAME = 'dnn/regression_head/dnn/learning_rate'\nHIDDEN_WEIGHTS_NAME_PATTERN = 'dnn/hiddenlayer_%d/kernel'\nHIDDEN_BIASES_NAME_PATTERN = 
'dnn/hiddenlayer_%d/bias'\nBATCH_NORM_BETA_NAME_PATTERN = 'dnn/hiddenlayer_%d/batchnorm_%d/beta'\nBATCH_NORM_GAMMA_NAME_PATTERN = 'dnn/hiddenlayer_%d/batchnorm_%d/gamma'\nBATCH_NORM_MEAN_NAME_PATTERN = 'dnn/hiddenlayer_%d/batchnorm_%d/moving_mean'\nBATCH_NORM_VARIANCE_NAME_PATTERN = (\n    'dnn/hiddenlayer_%d/batchnorm_%d/moving_variance')\nLOGITS_WEIGHTS_NAME = 'dnn/logits/kernel'\nLOGITS_BIASES_NAME = 'dnn/logits/bias'\nOCCUPATION_EMBEDDING_NAME = ('dnn/input_from_feature_columns/input_layer/'\n                             'occupation_embedding/embedding_weights')\nCITY_EMBEDDING_NAME = ('dnn/input_from_feature_columns/input_layer/'\n                       'city_embedding/embedding_weights')\n\n\ndef assert_close(expected, actual, rtol=1e-04, message='', name='assert_close'):\n  with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope:\n    expected = ops.convert_to_tensor(expected, name='expected')\n    actual = ops.convert_to_tensor(actual, name='actual')\n    rdiff = tf.math.abs((expected - actual) / expected, 'diff')\n    rtol = ops.convert_to_tensor(rtol, name='rtol')\n    return tf.compat.v1.debugging.assert_less(\n        rdiff,\n        rtol,\n        data=('Condition expected =~ actual did not hold element-wise:'\n              'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff,\n              'rtol = ', rtol,),\n        summarize=expected.get_shape().num_elements(),\n        name=scope)\n\n\ndef create_checkpoint(weights_and_biases,\n                      global_step,\n                      model_dir,\n                      batch_norm_vars=None):\n  \"\"\"Create checkpoint file with provided model weights.\n\n  Args:\n    weights_and_biases: Iterable of tuples of weight and bias values.\n    global_step: Initial global step to save in checkpoint.\n    model_dir: Directory into which checkpoint is saved.\n    batch_norm_vars: Variables used for batch normalization.\n  \"\"\"\n  weights, biases = zip(*weights_and_biases)\n 
 if batch_norm_vars:\n    assert len(batch_norm_vars) == len(weights_and_biases) - 1\n    (bn_betas, bn_gammas, bn_means, bn_variances) = zip(*batch_norm_vars)\n  model_weights = {}\n\n  # Hidden layer weights.\n  for i in range(0, len(weights) - 1):\n    model_weights[HIDDEN_WEIGHTS_NAME_PATTERN % i] = weights[i]\n    model_weights[HIDDEN_BIASES_NAME_PATTERN % i] = biases[i]\n    if batch_norm_vars:\n      model_weights[BATCH_NORM_BETA_NAME_PATTERN % (i, i)] = bn_betas[i]\n      model_weights[BATCH_NORM_GAMMA_NAME_PATTERN % (i, i)] = bn_gammas[i]\n      model_weights[BATCH_NORM_MEAN_NAME_PATTERN % (i, i)] = bn_means[i]\n      model_weights[BATCH_NORM_VARIANCE_NAME_PATTERN % (i, i)] = bn_variances[i]\n\n  # Output layer weights.\n  model_weights[LOGITS_WEIGHTS_NAME] = weights[-1]\n  model_weights[LOGITS_BIASES_NAME] = biases[-1]\n\n  with tf.Graph().as_default():\n    # Create model variables.\n    for k, v in six.iteritems(model_weights):\n      tf.Variable(v, name=k, dtype=tf.dtypes.float32)\n\n    # Create non-model variables.\n    global_step_var = tf.compat.v1.train.create_global_step()\n\n    # Initialize vars and save checkpoint.\n    with tf.compat.v1.Session() as sess:\n      tf.compat.v1.initializers.global_variables().run()\n      global_step_var.assign(global_step).eval()\n      tf.compat.v1.train.Saver().save(sess,\n                                      os.path.join(model_dir, 'model.ckpt'))\n\n\ndef mock_head(testcase, hidden_units, logits_dimension, expected_logits):\n  \"\"\"Returns a mock head that validates logits values and variable names.\"\"\"\n  hidden_weights_names = [(HIDDEN_WEIGHTS_NAME_PATTERN + ':0') % i\n                          for i in range(len(hidden_units))]\n  hidden_biases_names = [\n      (HIDDEN_BIASES_NAME_PATTERN + ':0') % i for i in range(len(hidden_units))\n  ]\n  expected_var_names = (\n      hidden_weights_names + hidden_biases_names +\n      [LOGITS_WEIGHTS_NAME + ':0', LOGITS_BIASES_NAME + ':0'])\n\n  def 
_create_tpu_estimator_spec(features,\n                                 mode,\n                                 logits,\n                                 labels,\n                                 trainable_variables=None,\n                                 train_op_fn=None,\n                                 optimizer=None,\n                                 update_ops=None):\n    del features, labels  # Not used.\n    trainable_vars = tf.compat.v1.get_collection(\n        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n    testcase.assertItemsEqual(expected_var_names,\n                              [var.name for var in trainable_vars])\n    loss = tf.constant(1.)\n    assert_logits = assert_close(\n        expected_logits, logits, message='Failed for mode={}. '.format(mode))\n    with tf.control_dependencies([assert_logits]):\n      if mode == ModeKeys.TRAIN:\n        if train_op_fn is not None:\n          train_op = train_op_fn(loss)\n        elif optimizer is not None:\n          train_op = optimizer.get_updates(loss, trainable_variables)\n        if update_ops is not None:\n          train_op = tf.group(train_op, *update_ops)\n        return model_fn._TPUEstimatorSpec(\n            mode=mode, loss=loss, train_op=train_op)\n      elif mode == ModeKeys.EVAL:\n        return model_fn._TPUEstimatorSpec(mode=mode, loss=tf.identity(loss))\n      elif mode == ModeKeys.PREDICT:\n        return model_fn._TPUEstimatorSpec(\n            mode=mode, predictions={'logits': tf.identity(logits)})\n      else:\n        testcase.fail('Invalid mode: {}'.format(mode))\n\n  def _create_estimator_spec(features,\n                             mode,\n                             logits,\n                             labels,\n                             trainable_variables=None,\n                             train_op_fn=None,\n                             optimizer=None,\n                             update_ops=None):\n    tpu_spec = _create_tpu_estimator_spec(features, mode, logits, 
labels,\n                                          trainable_variables, train_op_fn,\n                                          optimizer, update_ops)\n    return tpu_spec.as_estimator_spec()\n\n  head = tf.compat.v1.test.mock.NonCallableMagicMock(spec=base_head.Head)\n  head.logits_dimension = logits_dimension\n  head._create_tpu_estimator_spec = tf.compat.v1.test.mock.MagicMock(\n      wraps=_create_tpu_estimator_spec)\n  head.create_estimator_spec = tf.compat.v1.test.mock.MagicMock(\n      wraps=_create_estimator_spec)\n\n  return head\n\n\ndef mock_optimizer(testcase, hidden_units, expected_loss=None):\n  \"\"\"Creates a mock optimizer to test the train method.\n\n  Args:\n    testcase: A TestCase instance.\n    hidden_units: Iterable of integer sizes for the hidden layers.\n    expected_loss: If given, will assert the loss value.\n\n  Returns:\n    A mock Optimizer.\n  \"\"\"\n  hidden_weights_names = [(HIDDEN_WEIGHTS_NAME_PATTERN + ':0') % i\n                          for i in range(len(hidden_units))]\n  hidden_biases_names = [\n      (HIDDEN_BIASES_NAME_PATTERN + ':0') % i for i in range(len(hidden_units))\n  ]\n  expected_var_names = (\n      hidden_weights_names + hidden_biases_names +\n      [LOGITS_WEIGHTS_NAME + ':0', LOGITS_BIASES_NAME + ':0'])\n\n  class _Optimizer(tf_keras.optimizers.legacy.Optimizer):\n\n    def get_updates(self, loss, params):\n      trainable_vars = params\n      testcase.assertItemsEqual(expected_var_names,\n                                [var.name for var in trainable_vars])\n\n      # Verify loss. 
We can't check the value directly, so we add an assert op.\n      testcase.assertEquals(0, loss.shape.ndims)\n      if expected_loss is None:\n        if self.iterations is not None:\n          return [self.iterations.assign_add(1).op]\n        return [tf.no_op()]\n      assert_loss = assert_close(\n          tf.cast(expected_loss, name='expected', dtype=tf.dtypes.float32),\n          loss,\n          name='assert_loss')\n      with tf.control_dependencies((assert_loss,)):\n        if self.iterations is not None:\n          return [self.iterations.assign_add(1).op]\n        return [tf.no_op()]\n\n    def get_config(self):\n      config = super(_Optimizer, self).get_config()\n      return config\n\n  optimizer = _Optimizer(name='my_optimizer')\n\n  return optimizer\n\n\nclass BaseDNNModelFnTest(object):\n  \"\"\"Tests that _dnn_model_fn passes expected logits to mock head.\"\"\"\n\n  def __init__(self, dnn_model_fn, fc_impl=feature_column_v2):\n    self._dnn_model_fn = dnn_model_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_logits(self, mode, hidden_units, logits_dimension, inputs,\n                   expected_logits):\n    \"\"\"Tests that the expected logits are passed to mock head.\"\"\"\n    with tf.Graph().as_default():\n      tf.compat.v1.train.create_global_step()\n      head = mock_head(\n          self,\n          hidden_units=hidden_units,\n          logits_dimension=logits_dimension,\n          expected_logits=expected_logits)\n      estimator_spec = self._dnn_model_fn(\n          features={'age': tf.constant(inputs)},\n          labels=tf.constant([[1]]),\n          mode=mode,\n          head=head,\n          hidden_units=hidden_units,\n          feature_columns=[\n              self._fc_impl.numeric_column(\n                  'age', 
shape=np.array(inputs).shape[1:])\n          ],\n          optimizer=mock_optimizer(self, hidden_units))\n      with tf.compat.v1.train.MonitoredTrainingSession(\n          checkpoint_dir=self._model_dir) as sess:\n        if mode == ModeKeys.TRAIN:\n          sess.run(estimator_spec.train_op)\n        elif mode == ModeKeys.EVAL:\n          sess.run(estimator_spec.loss)\n        elif mode == ModeKeys.PREDICT:\n          sess.run(estimator_spec.predictions)\n        else:\n          self.fail('Invalid mode: {}'.format(mode))\n\n  def test_one_dim_logits(self):\n    \"\"\"Tests one-dimensional logits.\n\n    input_layer = [[10]]\n    hidden_layer_0 = [[relu(0.6*10 +0.1), relu(0.5*10 -0.1)]] = [[6.1, 4.9]]\n    hidden_layer_1 = [[relu(1*6.1 -0.8*4.9 +0.2), relu(0.8*6.1 -1*4.9 -0.1)]]\n                   = [[relu(2.38), relu(-0.12)]] = [[2.38, 0]]\n    logits = [[-1*2.38 +1*0 +0.3]] = [[-2.08]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), base_global_step, self._model_dir)\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=1,\n          inputs=[[10.]],\n          expected_logits=[[-2.08]])\n\n  def test_multi_dim_logits(self):\n    \"\"\"Tests multi-dimensional logits.\n\n    input_layer = [[10]]\n    hidden_layer_0 = [[relu(0.6*10 +0.1), relu(0.5*10 -0.1)]] = [[6.1, 4.9]]\n    hidden_layer_1 = [[relu(1*6.1 -0.8*4.9 +0.2), relu(0.8*6.1 -1*4.9 -0.1)]]\n                   = [[relu(2.38), relu(-0.12)]] = [[2.38, 0]]\n    logits = [[-1*2.38 +0.3, 1*2.38 -0.3, 0.5*2.38]]\n           = [[-2.08, 2.08, 1.19]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), 
base_global_step, self._model_dir)\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=3,\n          inputs=[[10.]],\n          expected_logits=[[-2.08, 2.08, 1.19]])\n\n  def test_multi_example_multi_dim_logits(self):\n    \"\"\"Tests multiple examples and multi-dimensional logits.\n\n    input_layer = [[10], [5]]\n    hidden_layer_0 = [[relu(0.6*10 +0.1), relu(0.5*10 -0.1)],\n                      [relu(0.6*5 +0.1), relu(0.5*5 -0.1)]]\n                   = [[6.1, 4.9], [3.1, 2.4]]\n    hidden_layer_1 = [[relu(1*6.1 -0.8*4.9 +0.2), relu(0.8*6.1 -1*4.9 -0.1)],\n                      [relu(1*3.1 -0.8*2.4 +0.2), relu(0.8*3.1 -1*2.4 -0.1)]]\n                   = [[2.38, 0], [1.38, 0]]\n    logits = [[-1*2.38 +0.3, 1*2.38 -0.3, 0.5*2.38],\n              [-1*1.38 +0.3, 1*1.38 -0.3, 0.5*1.38]]\n           = [[-2.08, 2.08, 1.19], [-1.08, 1.08, 0.69]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=3,\n          inputs=[[10.], [5.]],\n          expected_logits=[[-2.08, 2.08, 1.19], [-1.08, 1.08, .69]])\n\n  def test_multi_dim_input_one_dim_logits(self):\n    \"\"\"Tests multi-dimensional inputs and one-dimensional logits.\n\n    input_layer = [[10, 8]]\n    hidden_layer_0 = [[relu(0.6*10 -0.6*8 +0.1), relu(0.5*10 -0.5*8 -0.1)]]\n                   = [[1.3, 0.9]]\n    hidden_layer_1 = [[relu(1*1.3 -0.8*0.9 + 0.2), relu(0.8*1.3 -1*0.9 -0.2)]]\n                   = [[0.78, relu(-0.06)]] = [[0.78, 0]]\n    logits = [[-1*0.78 +1*0 +0.3]] = [[-0.48]]\n    \"\"\"\n    base_global_step = 100\n    
create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), base_global_step, self._model_dir)\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=1,\n          inputs=[[10., 8.]],\n          expected_logits=[[-0.48]])\n\n  def test_multi_dim_input_multi_dim_logits(self):\n    \"\"\"Tests multi-dimensional inputs and multi-dimensional logits.\n\n    input_layer = [[10, 8]]\n    hidden_layer_0 = [[relu(0.6*10 -0.6*8 +0.1), relu(0.5*10 -0.5*8 -0.1)]]\n                   = [[1.3, 0.9]]\n    hidden_layer_1 = [[relu(1*1.3 -0.8*0.9 + 0.2), relu(0.8*1.3 -1*0.9 -0.2)]]\n                   = [[0.78, relu(-0.06)]] = [[0.78, 0]]\n    logits = [[-1*0.78 + 0.3, 1*0.78 -0.3, 0.5*0.78]] = [[-0.48, 0.48, 0.39]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=3,\n          inputs=[[10., 8.]],\n          expected_logits=[[-0.48, 0.48, 0.39]])\n\n  def test_multi_feature_column_multi_dim_logits(self):\n    \"\"\"Tests multiple feature columns and multi-dimensional logits.\n\n    All numbers are the same as test_multi_dim_input_multi_dim_logits. 
The only\n    difference is that the input consists of two 1D feature columns, instead of\n    one 2D feature column.\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n    hidden_units = (2, 2)\n    logits_dimension = 3\n    inputs = ([[10.]], [[8.]])\n    expected_logits = [[-0.48, 0.48, 0.39]]\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      with tf.Graph().as_default():\n        tf.compat.v1.train.create_global_step()\n        head = mock_head(\n            self,\n            hidden_units=hidden_units,\n            logits_dimension=logits_dimension,\n            expected_logits=expected_logits)\n        estimator_spec = self._dnn_model_fn(\n            features={\n                'age': tf.constant(inputs[0]),\n                'height': tf.constant(inputs[1])\n            },\n            labels=tf.constant([[1]]),\n            mode=mode,\n            head=head,\n            hidden_units=hidden_units,\n            feature_columns=[\n                self._fc_impl.numeric_column('age'),\n                self._fc_impl.numeric_column('height')\n            ],\n            optimizer=mock_optimizer(self, hidden_units))\n        with tf.compat.v1.train.MonitoredTrainingSession(\n            checkpoint_dir=self._model_dir) as sess:\n          if mode == ModeKeys.TRAIN:\n            sess.run(estimator_spec.train_op)\n          elif mode == ModeKeys.EVAL:\n            sess.run(estimator_spec.loss)\n          elif mode == ModeKeys.PREDICT:\n            sess.run(estimator_spec.predictions)\n          else:\n            self.fail('Invalid mode: {}'.format(mode))\n\n  def test_multi_feature_column_mix_multi_dim_logits(self):\n    \"\"\"Tests multiple feature columns and multi-dimensional logits.\n\n    All numbers are the same as 
test_multi_dim_input_multi_dim_logits. The only\n    difference is that the input consists of two 1D feature columns, instead of\n    one 2D feature column.\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n    hidden_units = (2, 2)\n    logits_dimension = 3\n    inputs = ([[10.]], [[8.]])\n    expected_logits = [[-0.48, 0.48, 0.39]]\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      with tf.Graph().as_default():\n        tf.compat.v1.train.create_global_step()\n        head = mock_head(\n            self,\n            hidden_units=hidden_units,\n            logits_dimension=logits_dimension,\n            expected_logits=expected_logits)\n        estimator_spec = self._dnn_model_fn(\n            features={\n                'age': tf.constant(inputs[0]),\n                'height': tf.constant(inputs[1])\n            },\n            labels=tf.constant([[1]]),\n            mode=mode,\n            head=head,\n            hidden_units=hidden_units,\n            feature_columns=[\n                tf.feature_column.numeric_column('age'),\n                tf.feature_column.numeric_column('height')\n            ],\n            optimizer=mock_optimizer(self, hidden_units))\n        with tf.compat.v1.train.MonitoredTrainingSession(\n            checkpoint_dir=self._model_dir) as sess:\n          if mode == ModeKeys.TRAIN:\n            sess.run(estimator_spec.train_op)\n          elif mode == ModeKeys.EVAL:\n            sess.run(estimator_spec.loss)\n          elif mode == ModeKeys.PREDICT:\n            sess.run(estimator_spec.predictions)\n          else:\n            self.fail('Invalid mode: {}'.format(mode))\n\n  def test_features_tensor_raises_value_error(self):\n    \"\"\"Tests that passing a Tensor for features raises a ValueError.\"\"\"\n  
  hidden_units = (2, 2)\n    logits_dimension = 3\n    inputs = ([[10.]], [[8.]])\n    expected_logits = [[0, 0, 0]]\n\n    with tf.Graph().as_default():\n      tf.compat.v1.train.create_global_step()\n      head = mock_head(\n          self,\n          hidden_units=hidden_units,\n          logits_dimension=logits_dimension,\n          expected_logits=expected_logits)\n      with self.assertRaisesRegexp(ValueError, 'features should be a dict'):\n        self._dnn_model_fn(\n            features=tf.constant(inputs),\n            labels=tf.constant([[1]]),\n            mode=ModeKeys.TRAIN,\n            head=head,\n            hidden_units=hidden_units,\n            feature_columns=[\n                self._fc_impl.numeric_column(\n                    'age', shape=np.array(inputs).shape[1:])\n            ],\n            optimizer=mock_optimizer(self, hidden_units))\n\n\nclass BaseDNNLogitFnTest(object):\n  \"\"\"Tests correctness of logits calculated from _dnn_logit_fn_builder.\"\"\"\n\n  def __init__(self, dnn_logit_fn_builder, fc_impl=feature_column_v2):\n    self._dnn_logit_fn_builder = dnn_logit_fn_builder\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_logits(self,\n                   mode,\n                   hidden_units,\n                   logits_dimension,\n                   inputs,\n                   expected_logits,\n                   batch_norm=False):\n    \"\"\"Tests that the expected logits are calculated.\"\"\"\n    with tf.Graph().as_default():\n      # Global step needed for MonitoredSession, which is in turn used to\n      # explicitly set variable weights through a checkpoint.\n      tf.compat.v1.train.create_global_step()\n      logit_fn = self._dnn_logit_fn_builder(\n          units=logits_dimension,\n          hidden_units=hidden_units,\n         
 feature_columns=[\n              self._fc_impl.numeric_column(\n                  'age', shape=np.array(inputs).shape[1:])\n          ],\n          activation_fn=tf.nn.relu,\n          dropout=None,\n          batch_norm=batch_norm)\n      logits = logit_fn(features={'age': tf.constant(inputs)}, mode=mode)\n      with tf.compat.v1.train.MonitoredTrainingSession(\n          checkpoint_dir=self._model_dir) as sess:\n        self.assertAllClose(expected_logits, sess.run(logits))\n\n  def test_one_dim_logits(self):\n    \"\"\"Tests one-dimensional logits.\n\n    input_layer = [[10]]\n    hidden_layer_0 = [[relu(0.6*10 +0.1), relu(0.5*10 -0.1)]] = [[6.1, 4.9]]\n    hidden_layer_1 = [[relu(1*6.1 -0.8*4.9 +0.2), relu(0.8*6.1 -1*4.9 -0.1)]]\n                   = [[relu(2.38), relu(-0.12)]] = [[2.38, 0]]\n    logits = [[-1*2.38 +1*0 +0.3]] = [[-2.08]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), base_global_step, self._model_dir)\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=1,\n          inputs=[[10.]],\n          expected_logits=[[-2.08]])\n\n  def test_one_dim_logits_with_batch_norm(self):\n    \"\"\"Tests one-dimensional logits.\n\n    input_layer = [[10]]\n    hidden_layer_0 = [[relu(0.6*10 +1), relu(0.5*10 -1)]] = [[7, 4]]\n    hidden_layer_0 = [[relu(0.6*20 +1), relu(0.5*20 -1)]] = [[13, 9]]\n\n    batch_norm_0, training (epsilon = 0.001):\n      mean1 = 1/2*(7+13) = 10,\n      variance1 = 1/2*(3^2+3^2) = 9\n      x11 = (7-10)/sqrt(9+0.001) = -0.999944449,\n      x21 = (13-10)/sqrt(9+0.001) = 0.999944449,\n\n      mean2 = 1/2*(4+9) = 6.5,\n      variance2 = 1/2*(2.5^2+.2.5^2) = 6.25\n      x12 = (4-6.5)/sqrt(6.25+0.001) = -0.99992001,\n      x22 = (9-6.5)/sqrt(6.25+0.001) = 0.99992001,\n\n    logits = 
[[-1*(-0.999944449) + 2*(-0.99992001) + 0.3],\n              [-1*0.999944449 + 2*0.99992001 + 0.3]]\n           = [[-0.699895571],[1.299895571]]\n\n    batch_norm_0, not training (epsilon = 0.001):\n      moving_mean1 = 0, moving_variance1 = 1\n      x11 = (7-0)/sqrt(1+0.001) = 6.996502623,\n      x21 = (13-0)/sqrt(1+0.001) = 12.993504871,\n      moving_mean2 = 0, moving_variance2 = 1\n      x12 = (4-0)/sqrt(1+0.001) = 3.998001499,\n      x22 = (9-0)/sqrt(1+0.001) = 8.995503372,\n\n    logits = [[-1*6.996502623 + 2*3.998001499 + 0.3],\n              [-1*12.993504871 + 2*8.995503372 + 0.3]]\n           = [[1.299500375],[5.297501873]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint(\n        (\n            ([[.6, .5]], [1., -1.]),\n            ([[-1.], [2.]], [.3]),\n        ),\n        base_global_step,\n        self._model_dir,\n        batch_norm_vars=(\n            [\n                [0, 0],  # beta.\n                [1, 1],  # gamma.\n                [0, 0],  # moving mean.\n                [1, 1],  # moving variance.\n            ],))\n    self._test_logits(\n        ModeKeys.TRAIN,\n        hidden_units=[2],\n        logits_dimension=1,\n        inputs=[[10.], [20.]],\n        expected_logits=[[-0.699895571], [1.299895571]],\n        batch_norm=True)\n    for mode in [ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=[2],\n          logits_dimension=1,\n          inputs=[[10.], [20.]],\n          expected_logits=[[1.299500375], [5.297501873]],\n          batch_norm=True)\n\n  def test_multi_dim_logits(self):\n    \"\"\"Tests multi-dimensional logits.\n\n    input_layer = [[10]]\n    hidden_layer_0 = [[relu(0.6*10 +0.1), relu(0.5*10 -0.1)]] = [[6.1, 4.9]]\n    hidden_layer_1 = [[relu(1*6.1 -0.8*4.9 +0.2), relu(0.8*6.1 -1*4.9 -0.1)]]\n                   = [[relu(2.38), relu(-0.12)]] = [[2.38, 0]]\n    logits = [[-1*2.38 +0.3, 1*2.38 -0.3, 0.5*2.38]]\n           = [[-2.08, 2.08, 1.19]]\n    
\"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=3,\n          inputs=[[10.]],\n          expected_logits=[[-2.08, 2.08, 1.19]])\n\n  def test_multi_example_multi_dim_logits(self):\n    \"\"\"Tests multiple examples and multi-dimensional logits.\n\n    input_layer = [[10], [5]]\n    hidden_layer_0 = [[relu(0.6*10 +0.1), relu(0.5*10 -0.1)],\n                      [relu(0.6*5 +0.1), relu(0.5*5 -0.1)]]\n                   = [[6.1, 4.9], [3.1, 2.4]]\n    hidden_layer_1 = [[relu(1*6.1 -0.8*4.9 +0.2), relu(0.8*6.1 -1*4.9 -0.1)],\n                      [relu(1*3.1 -0.8*2.4 +0.2), relu(0.8*3.1 -1*2.4 -0.1)]]\n                   = [[2.38, 0], [1.38, 0]]\n    logits = [[-1*2.38 +0.3, 1*2.38 -0.3, 0.5*2.38],\n              [-1*1.38 +0.3, 1*1.38 -0.3, 0.5*1.38]]\n           = [[-2.08, 2.08, 1.19], [-1.08, 1.08, 0.69]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=3,\n          inputs=[[10.], [5.]],\n          expected_logits=[[-2.08, 2.08, 1.19], [-1.08, 1.08, .69]])\n\n  def test_multi_dim_input_one_dim_logits(self):\n    \"\"\"Tests multi-dimensional inputs and one-dimensional logits.\n\n    input_layer = [[10, 8]]\n    hidden_layer_0 = [[relu(0.6*10 -0.6*8 +0.1), relu(0.5*10 -0.5*8 -0.1)]]\n                   = [[1.3, 0.9]]\n    hidden_layer_1 = 
[[relu(1*1.3 -0.8*0.9 + 0.2), relu(0.8*1.3 -1*0.9 -0.2)]]\n                   = [[0.78, relu(-0.06)]] = [[0.78, 0]]\n    logits = [[-1*0.78 +1*0 +0.3]] = [[-0.48]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), base_global_step, self._model_dir)\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=1,\n          inputs=[[10., 8.]],\n          expected_logits=[[-0.48]])\n\n  def test_multi_dim_input_multi_dim_logits(self):\n    \"\"\"Tests multi-dimensional inputs and multi-dimensional logits.\n\n    input_layer = [[10, 8]]\n    hidden_layer_0 = [[relu(0.6*10 -0.6*8 +0.1), relu(0.5*10 -0.5*8 -0.1)]]\n                   = [[1.3, 0.9]]\n    hidden_layer_1 = [[relu(1*1.3 -0.8*0.9 + 0.2), relu(0.8*1.3 -1*0.9 -0.2)]]\n                   = [[0.78, relu(-0.06)]] = [[0.78, 0]]\n    logits = [[-1*0.78 + 0.3, 1*0.78 -0.3, 0.5*0.78]] = [[-0.48, 0.48, 0.39]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=3,\n          inputs=[[10., 8.]],\n          expected_logits=[[-0.48, 0.48, 0.39]])\n\n  def test_multi_feature_column_multi_dim_logits(self):\n    \"\"\"Tests multiple feature columns and multi-dimensional logits.\n\n    All numbers are the same as test_multi_dim_input_multi_dim_logits. 
The only\n    difference is that the input consists of two 1D feature columns, instead of\n    one 2D feature column.\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n\n    hidden_units = (2, 2)\n    logits_dimension = 3\n    inputs = ([[10.]], [[8.]])\n    expected_logits = [[-0.48, 0.48, 0.39]]\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      with tf.Graph().as_default():\n        # Global step needed for MonitoredSession, which is in turn used to\n        # explicitly set variable weights through a checkpoint.\n        tf.compat.v1.train.create_global_step()\n        logit_fn = self._dnn_logit_fn_builder(\n            units=logits_dimension,\n            hidden_units=hidden_units,\n            feature_columns=[\n                self._fc_impl.numeric_column('age'),\n                self._fc_impl.numeric_column('height')\n            ],\n            activation_fn=tf.nn.relu,\n            dropout=None,\n            batch_norm=False)\n        logits = logit_fn(\n            features={\n                'age': tf.constant(inputs[0]),\n                'height': tf.constant(inputs[1])\n            },\n            mode=mode)\n        with tf.compat.v1.train.MonitoredTrainingSession(\n            checkpoint_dir=self._model_dir) as sess:\n          self.assertAllClose(expected_logits, sess.run(logits))\n\n  def test_multi_feature_column_mix_multi_dim_logits(self):\n    \"\"\"Tests multiple feature columns and multi-dimensional logits.\n\n    All numbers are the same as test_multi_dim_input_multi_dim_logits. 
The only\n    difference is that the input consists of two 1D feature columns, instead of\n    one 2D feature column.\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n\n    hidden_units = (2, 2)\n    logits_dimension = 3\n    inputs = ([[10.]], [[8.]])\n    expected_logits = [[-0.48, 0.48, 0.39]]\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      with tf.Graph().as_default():\n        # Global step needed for MonitoredSession, which is in turn used to\n        # explicitly set variable weights through a checkpoint.\n        tf.compat.v1.train.create_global_step()\n        logit_fn = self._dnn_logit_fn_builder(\n            units=logits_dimension,\n            hidden_units=hidden_units,\n            feature_columns=[\n                tf.feature_column.numeric_column('age'),\n                tf.feature_column.numeric_column('height')\n            ],\n            activation_fn=tf.nn.relu,\n            dropout=None,\n            batch_norm=False)\n        logits = logit_fn(\n            features={\n                'age': tf.constant(inputs[0]),\n                'height': tf.constant(inputs[1])\n            },\n            mode=mode)\n        with tf.compat.v1.train.MonitoredTrainingSession(\n            checkpoint_dir=self._model_dir) as sess:\n          self.assertAllClose(expected_logits, sess.run(logits))\n\n\nclass BaseDNNWarmStartingTest(object):\n\n  def __init__(self,\n               _dnn_classifier_fn,\n               _dnn_regressor_fn,\n               fc_impl=feature_column_v2):\n    self._dnn_classifier_fn = _dnn_classifier_fn\n    self._dnn_regressor_fn = _dnn_regressor_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    # Create a directory to save our old checkpoint and vocabularies to.\n    self._ckpt_and_vocab_dir = 
tempfile.mkdtemp()\n    # Reset the default graph in each test method to avoid the Keras optimizer\n    # naming issue during warm starting.\n    tf.compat.v1.reset_default_graph()\n\n    # Make a dummy input_fn.\n    def _input_fn():\n      features = {\n          'city': [['Palo Alto'], ['Mountain View']],\n          'locality': [['Palo Alto'], ['Mountain View']],\n          'occupation': [['doctor'], ['consultant']]\n      }\n      return features, [0, 1]\n\n    self._input_fn = _input_fn\n\n  def tearDown(self):\n    # Clean up checkpoint / vocab dir.\n    tf.compat.v1.summary.FileWriterCache.clear()\n    shutil.rmtree(self._ckpt_and_vocab_dir)\n\n  def assertAllNotClose(self, t1, t2):\n    \"\"\"Helper assert for arrays.\"\"\"\n    sum_of_abs_diff = 0.0\n    for x, y in zip(t1, t2):\n      try:\n        for a, b in zip(x, y):\n          sum_of_abs_diff += abs(b - a)\n      except TypeError:\n        sum_of_abs_diff += abs(y - x)\n    self.assertGreater(sum_of_abs_diff, 0)\n\n  def test_classifier_basic_warm_starting(self):\n    \"\"\"Tests correctness of DNNClassifier default warm-start.\"\"\"\n    city = self._fc_impl.embedding_column(\n        self._fc_impl.categorical_column_with_vocabulary_list(\n            'city', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n\n    # Create a DNNClassifier and train to save a checkpoint.\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[city],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD')\n    dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNClassifier, warm-started from the first.  
Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    warm_started_dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[city],\n        n_classes=4,\n        optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n        warm_start_from=dnn_classifier.model_dir)\n\n    warm_started_dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_dnn_classifier.get_variable_names():\n      # Learning rate is also checkpointed in V2 optimizer. So we need to make\n      # sure it uses the new value after warm started.\n      if 'learning_rate' in variable_name:\n        self.assertAllClose(\n            0.0, warm_started_dnn_classifier.get_variable_value(variable_name))\n      else:\n        self.assertAllClose(\n            dnn_classifier.get_variable_value(variable_name),\n            warm_started_dnn_classifier.get_variable_value(variable_name))\n\n  def test_regressor_basic_warm_starting(self):\n    \"\"\"Tests correctness of DNNRegressor default warm-start.\"\"\"\n    city = self._fc_impl.embedding_column(\n        self._fc_impl.categorical_column_with_vocabulary_list(\n            'city', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n\n    # Create a DNNRegressor and train to save a checkpoint.\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=[256, 128],\n        feature_columns=[city],\n        model_dir=self._ckpt_and_vocab_dir,\n        optimizer='SGD')\n    dnn_regressor.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNRegressor, warm-started from the first.  
Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    warm_started_dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=[256, 128],\n        feature_columns=[city],\n        optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n        warm_start_from=dnn_regressor.model_dir)\n\n    warm_started_dnn_regressor.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_dnn_regressor.get_variable_names():\n      # Learning rate is also checkpointed in V2 optimizer. So we need to make\n      # sure it uses the new value after warm started.\n      if 'learning_rate' in variable_name:\n        self.assertAllClose(\n            0.0, warm_started_dnn_regressor.get_variable_value(variable_name))\n      else:\n        self.assertAllClose(\n            dnn_regressor.get_variable_value(variable_name),\n            warm_started_dnn_regressor.get_variable_value(variable_name))\n\n  def test_warm_starting_selective_variables(self):\n    \"\"\"Tests selecting variables to warm-start.\"\"\"\n    city = self._fc_impl.embedding_column(\n        self._fc_impl.categorical_column_with_vocabulary_list(\n            'city', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n\n    # Create a DNNClassifier and train to save a checkpoint.\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[city],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD')\n    dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNClassifier, warm-started from the first.  
Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    warm_started_dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[city],\n        n_classes=4,\n        optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n        # The provided regular expression will only warm-start the city\n        # embedding, not the kernels and biases of the hidden weights.\n        warm_start_from=estimator.WarmStartSettings(\n            ckpt_to_initialize_from=dnn_classifier.model_dir,\n            vars_to_warm_start='.*(city).*'))\n\n    warm_started_dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_dnn_classifier.get_variable_names():\n      if 'city' in variable_name:\n        self.assertAllClose(\n            dnn_classifier.get_variable_value(variable_name),\n            warm_started_dnn_classifier.get_variable_value(variable_name))\n      elif 'bias' in variable_name:\n        # Hidden layer biases are zero-initialized.\n        bias_values = warm_started_dnn_classifier.get_variable_value(\n            variable_name)\n        self.assertAllClose(np.zeros_like(bias_values), bias_values)\n      elif 'kernel' in variable_name:\n        # We can't override the glorot uniform initializer used for the kernels\n        # in the dense layers, so just make sure we're not getting the same\n        # values from the old checkpoint.\n        self.assertAllNotClose(\n            dnn_classifier.get_variable_value(variable_name),\n            warm_started_dnn_classifier.get_variable_value(variable_name))\n\n  def test_warm_starting_with_vocab_remapping(self):\n    \"\"\"Tests warm-starting with vocab remapping.\"\"\"\n    vocab_list = ['doctor', 'lawyer', 'consultant']\n    vocab_file = os.path.join(self._ckpt_and_vocab_dir, 'occupation_vocab')\n    with open(vocab_file, 'w') as f:\n      
f.write('\\n'.join(vocab_list))\n    occupation = self._fc_impl.embedding_column(\n        self._fc_impl.categorical_column_with_vocabulary_file(\n            'occupation',\n            vocabulary_file=vocab_file,\n            vocabulary_size=len(vocab_list)),\n        dimension=2)\n\n    # Create a DNNClassifier and train to save a checkpoint.\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[occupation],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD')\n    dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNClassifier, warm-started from the first.  Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).  Use a new FeatureColumn with a\n    # different vocabulary for occupation.\n    new_vocab_list = ['doctor', 'consultant', 'engineer']\n    new_vocab_file = os.path.join(self._ckpt_and_vocab_dir,\n                                  'new_occupation_vocab')\n    with open(new_vocab_file, 'w') as f:\n      f.write('\\n'.join(new_vocab_list))\n    new_occupation = self._fc_impl.embedding_column(\n        self._fc_impl.categorical_column_with_vocabulary_file(\n            'occupation',\n            vocabulary_file=new_vocab_file,\n            vocabulary_size=len(new_vocab_list)),\n        dimension=2)\n    # We can create our VocabInfo object from the new and old occupation\n    # FeatureColumn's.\n    occupation_vocab_info = estimator.VocabInfo(\n        new_vocab=new_occupation.categorical_column.vocabulary_file,\n        new_vocab_size=new_occupation.categorical_column.vocabulary_size,\n        num_oov_buckets=new_occupation.categorical_column.num_oov_buckets,\n        old_vocab=occupation.categorical_column.vocabulary_file,\n        old_vocab_size=occupation.categorical_column.vocabulary_size,\n        # Can't use constant_initializer with load_and_remap.  
In practice,\n        # use a truncated normal initializer.\n        backup_initializer=tf.compat.v1.initializers.random_uniform(\n            minval=0.39, maxval=0.39))\n    warm_started_dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[occupation],\n        n_classes=4,\n        optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n        warm_start_from=estimator.WarmStartSettings(\n            ckpt_to_initialize_from=dnn_classifier.model_dir,\n            var_name_to_vocab_info={\n                OCCUPATION_EMBEDDING_NAME: occupation_vocab_info\n            },\n            # Explicitly providing None here will only warm-start variables\n            # referenced in var_name_to_vocab_info (no hidden weights will be\n            # warmstarted).\n            vars_to_warm_start=None))\n\n    warm_started_dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n    # 'doctor' was ID-0 and still ID-0.\n    self.assertAllClose(\n        dnn_classifier.get_variable_value(OCCUPATION_EMBEDDING_NAME)[0, :],\n        warm_started_dnn_classifier.get_variable_value(\n            OCCUPATION_EMBEDDING_NAME)[0, :])\n    # 'consultant' was ID-2 and now ID-1.\n    self.assertAllClose(\n        dnn_classifier.get_variable_value(OCCUPATION_EMBEDDING_NAME)[2, :],\n        warm_started_dnn_classifier.get_variable_value(\n            OCCUPATION_EMBEDDING_NAME)[1, :])\n    # 'engineer' is a new entry and should be initialized with the\n    # backup_initializer in VocabInfo.\n    self.assertAllClose([0.39] * 2,\n                        warm_started_dnn_classifier.get_variable_value(\n                            OCCUPATION_EMBEDDING_NAME)[2, :])\n    for variable_name in warm_started_dnn_classifier.get_variable_names():\n      if 'bias' in variable_name:\n        # Hidden layer biases are zero-initialized.\n        bias_values = warm_started_dnn_classifier.get_variable_value(\n            variable_name)\n        
self.assertAllClose(np.zeros_like(bias_values), bias_values)\n      elif 'kernel' in variable_name:\n        # We can't override the glorot uniform initializer used for the kernels\n        # in the dense layers, so just make sure we're not getting the same\n        # values from the old checkpoint.\n        self.assertAllNotClose(\n            dnn_classifier.get_variable_value(variable_name),\n            warm_started_dnn_classifier.get_variable_value(variable_name))\n\n  def test_warm_starting_with_naming_change(self):\n    \"\"\"Tests warm-starting with a Tensor name remapping.\"\"\"\n    locality = self._fc_impl.embedding_column(\n        self._fc_impl.categorical_column_with_vocabulary_list(\n            'locality', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n\n    # Create a DNNClassifier and train to save a checkpoint.\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[locality],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD')\n    dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNClassifier, warm-started from the first.  
Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    city = self._fc_impl.embedding_column(\n        self._fc_impl.categorical_column_with_vocabulary_list(\n            'city', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n    warm_started_dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[city],\n        n_classes=4,\n        optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n        # The 'city' variable correspond to the 'locality' variable in the\n        # previous model.\n        warm_start_from=estimator.WarmStartSettings(\n            ckpt_to_initialize_from=dnn_classifier.model_dir,\n            var_name_to_prev_var_name={\n                CITY_EMBEDDING_NAME:\n                    CITY_EMBEDDING_NAME.replace('city', 'locality')\n            }))\n\n    warm_started_dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_dnn_classifier.get_variable_names():\n      if 'city' in variable_name:\n        self.assertAllClose(\n            dnn_classifier.get_variable_value(\n                CITY_EMBEDDING_NAME.replace('city', 'locality')),\n            warm_started_dnn_classifier.get_variable_value(CITY_EMBEDDING_NAME))\n      # Learning rate is also checkpointed in V2 optimizer. 
So we need to make\n      # sure it uses the new value after warm started.\n      elif 'learning_rate' in variable_name:\n        self.assertAllClose(\n            0.0, warm_started_dnn_classifier.get_variable_value(variable_name))\n      else:\n        self.assertAllClose(\n            dnn_classifier.get_variable_value(variable_name),\n            warm_started_dnn_classifier.get_variable_value(variable_name))\n\n\nclass BaseDNNClassifierEvaluateTest(object):\n\n  def __init__(self, dnn_classifier_fn, fc_impl=feature_column_v2):\n    self._dnn_classifier_fn = dnn_classifier_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_one_dim(self):\n    \"\"\"Asserts evaluation metrics for one-dimensional input and logits.\"\"\"\n    global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), global_step, self._model_dir)\n\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=(2, 2),\n        feature_columns=[self._fc_impl.numeric_column('age')],\n        model_dir=self._model_dir)\n\n    def _input_fn():\n      # batch_size = 2, one false label, and one true.\n      return {'age': [[10.], [10.]]}, [[1], [0]]\n\n    # Uses identical numbers as DNNModelTest.test_one_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [[-2.08], [-2.08]] =>\n    # logistic = 1/(1 + exp(-logits)) = [[0.11105597], [0.11105597]]\n    # loss = (-1. * log(0.111) -1. 
* log(0.889) = 2.31544200) / 2\n    expected_loss = 1.157721\n    self.assertAllClose(\n        {\n            metric_keys.MetricKeys.LOSS:\n                expected_loss,\n            metric_keys.MetricKeys.LOSS_MEAN:\n                expected_loss,\n            metric_keys.MetricKeys.ACCURACY:\n                0.5,\n            metric_keys.MetricKeys.PRECISION:\n                0.0,\n            metric_keys.MetricKeys.RECALL:\n                0.0,\n            metric_keys.MetricKeys.PREDICTION_MEAN:\n                0.11105597,\n            metric_keys.MetricKeys.LABEL_MEAN:\n                0.5,\n            metric_keys.MetricKeys.ACCURACY_BASELINE:\n                0.5,\n            # There is no good way to calculate AUC for only two data points.\n            # But that is what the algorithm returns.\n            metric_keys.MetricKeys.AUC:\n                0.5,\n            metric_keys.MetricKeys.AUC_PR:\n                0.5,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP:\n                global_step\n        },\n        dnn_classifier.evaluate(input_fn=_input_fn, steps=1))\n\n  def test_multi_dim(self):\n    \"\"\"Asserts evaluation metrics for multi-dimensional input and logits.\"\"\"\n    global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), global_step, self._model_dir)\n    n_classes = 3\n\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=(2, 2),\n        feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    def _input_fn():\n      # batch_size = 2, one false label, and one true.\n      return {'age': [[10., 8.], [10., 8.]]}, [[1], [0]]\n\n    # Uses identical numbers as\n    # DNNModelFnTest.test_multi_dim_input_multi_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [[-0.48, 
0.48, 0.39], [-0.48, 0.48, 0.39]]\n    # probabilities = exp(logits)/sum(exp(logits))\n    #               = [[0.16670536, 0.43538380, 0.39791084],\n    #                  [0.16670536, 0.43538380, 0.39791084]]\n    # loss = -log(0.43538380) - log(0.16670536)\n    expected_loss = 2.62305466 / 2  # batch size\n    self.assertAllClose(\n        {\n            metric_keys.MetricKeys.LOSS: expected_loss,\n            metric_keys.MetricKeys.LOSS_MEAN: expected_loss,\n            metric_keys.MetricKeys.ACCURACY: 0.5,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: global_step\n        }, dnn_classifier.evaluate(input_fn=_input_fn, steps=1))\n\n  def test_float_labels(self):\n    \"\"\"Asserts evaluation metrics for float labels in binary classification.\"\"\"\n    global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), global_step, self._model_dir)\n\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=(2, 2),\n        feature_columns=[self._fc_impl.numeric_column('age')],\n        model_dir=self._model_dir)\n\n    def _input_fn():\n      # batch_size = 2, one false label, and one true.\n      return {'age': [[10.], [10.]]}, [[0.8], [0.4]]\n\n    # Uses identical numbers as DNNModelTest.test_one_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [[-2.08], [-2.08]] =>\n    # logistic = 1/(1 + exp(-logits)) = [[0.11105597], [0.11105597]]\n    # loss = (-0.8 * log(0.111) -0.2 * log(0.889)\n    #        -0.4 * log(0.111) -0.6 * log(0.889)) / 2 = 2.7314420 / 2\n    expected_loss = 1.365721\n    metrics = dnn_classifier.evaluate(input_fn=_input_fn, steps=1)\n    self.assertAlmostEqual(expected_loss, metrics[metric_keys.MetricKeys.LOSS])\n\n  def test_multi_dim_weights(self):\n    \"\"\"Tests evaluation with weights.\"\"\"\n    # Uses same checkpoint with test_multi_dims\n    global_step = 100\n    create_checkpoint((\n        
([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), global_step, self._model_dir)\n    n_classes = 3\n\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=(2, 2),\n        feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],\n        n_classes=n_classes,\n        weight_column='w',\n        model_dir=self._model_dir)\n\n    def _input_fn():\n      # batch_size = 2, one false label, and one true.\n      return {'age': [[10., 8.], [10., 8.]], 'w': [[10.], [100.]]}, [[1], [0]]\n\n    # Uses identical numbers as test_multi_dims\n    # See that test for calculation of logits.\n    # loss = (-log(0.43538380)*10 - log(0.16670536)*100) / 2\n    expected_loss = 93.734\n    metrics = dnn_classifier.evaluate(input_fn=_input_fn, steps=1)\n    self.assertAlmostEqual(\n        expected_loss, metrics[metric_keys.MetricKeys.LOSS], places=3)\n\n\nclass BaseDNNRegressorEvaluateTest(object):\n\n  def __init__(self, dnn_regressor_fn, fc_impl=feature_column_v2):\n    self._dnn_regressor_fn = dnn_regressor_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_one_dim(self):\n    \"\"\"Asserts evaluation metrics for one-dimensional input and logits.\"\"\"\n    # Create checkpoint: num_inputs=1, hidden_units=(2, 2), num_outputs=1.\n    global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), global_step, self._model_dir)\n\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=(2, 2),\n        feature_columns=[self._fc_impl.numeric_column('age')],\n        model_dir=self._model_dir)\n\n    def _input_fn():\n      return {'age': [[10.]]}, [[1.]]\n\n    # 
Uses identical numbers as DNNModelTest.test_one_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [[-2.08]] => predictions = [-2.08].\n    # loss = (1+2.08)^2 = 9.4864\n    expected_loss = 9.4864\n    self.assertAllClose(\n        {\n            metric_keys.MetricKeys.LOSS: expected_loss,\n            metric_keys.MetricKeys.LOSS_MEAN: expected_loss,\n            metric_keys.MetricKeys.PREDICTION_MEAN: -2.08,\n            metric_keys.MetricKeys.LABEL_MEAN: 1.0,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: global_step\n        }, dnn_regressor.evaluate(input_fn=_input_fn, steps=1))\n\n  def test_multi_dim(self):\n    \"\"\"Asserts evaluation metrics for multi-dimensional input and logits.\"\"\"\n    # Create checkpoint: num_inputs=2, hidden_units=(2, 2), num_outputs=3.\n    global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), global_step, self._model_dir)\n    label_dimension = 3\n\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=(2, 2),\n        feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],\n        label_dimension=label_dimension,\n        model_dir=self._model_dir)\n\n    def _input_fn():\n      return {'age': [[10., 8.]]}, [[1., -1., 0.5]]\n\n    # Uses identical numbers as\n    # DNNModelFnTest.test_multi_dim_input_multi_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [[-0.48, 0.48, 0.39]]\n    # loss = (1+0.48)^2 + (-1-0.48)^2 + (0.5-0.39)^2 = 4.3929\n    # expected_loss = loss / 3\n    expected_loss = 1.4643\n    self.assertAllClose(\n        {\n            metric_keys.MetricKeys.LOSS: expected_loss,\n            metric_keys.MetricKeys.LOSS_MEAN: expected_loss,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 0.39 / 3.0,\n            metric_keys.MetricKeys.LABEL_MEAN: 0.5 / 3.0,\n            
tf.compat.v1.GraphKeys.GLOBAL_STEP: global_step\n        }, dnn_regressor.evaluate(input_fn=_input_fn, steps=1))\n\n  def test_multi_dim_weights(self):\n    \"\"\"Asserts evaluation metrics for multi-dimensional input and logits.\"\"\"\n    # same checkpoint with test_multi_dim.\n    global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), global_step, self._model_dir)\n    label_dimension = 3\n\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=(2, 2),\n        feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],\n        label_dimension=label_dimension,\n        weight_column='w',\n        model_dir=self._model_dir)\n\n    def _input_fn():\n      return {'age': [[10., 8.]], 'w': [10.]}, [[1., -1., 0.5]]\n\n    # Uses identical numbers as test_multi_dim.\n    # See that test for calculation of logits.\n    # loss = 4.3929*10/3\n    expected_loss = 14.643\n    metrics = dnn_regressor.evaluate(input_fn=_input_fn, steps=1)\n    self.assertAlmostEqual(\n        expected_loss, metrics[metric_keys.MetricKeys.LOSS], places=3)\n\n\nclass BaseDNNClassifierPredictTest(object):\n\n  def __init__(self, dnn_classifier_fn, fc_impl=feature_column_v2):\n    self._dnn_classifier_fn = dnn_classifier_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_one_dim(self, label_vocabulary, label_output_fn):\n    \"\"\"Asserts predictions for one-dimensional input and logits.\"\"\"\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ),\n                      global_step=0,\n                      model_dir=self._model_dir)\n\n    
dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=(2, 2),\n        label_vocabulary=label_vocabulary,\n        feature_columns=(self._fc_impl.numeric_column('x'),),\n        model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        x={'x': np.array([[10.]])}, batch_size=1, shuffle=False)\n    # Uses identical numbers as DNNModelTest.test_one_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [-2.08] =>\n    # logistic = exp(-2.08)/(1 + exp(-2.08)) = 0.11105597\n    # probabilities = [1-logistic, logistic] = [0.88894403, 0.11105597]\n    # class_ids = argmax(probabilities) = [0]\n    predictions = next(dnn_classifier.predict(input_fn=input_fn))\n    self.assertAllClose([-2.08],\n                        predictions[prediction_keys.PredictionKeys.LOGITS])\n    self.assertAllClose([0.11105597],\n                        predictions[prediction_keys.PredictionKeys.LOGISTIC])\n    self.assertAllClose(\n        [0.88894403, 0.11105597],\n        predictions[prediction_keys.PredictionKeys.PROBABILITIES])\n    self.assertAllClose([0],\n                        predictions[prediction_keys.PredictionKeys.CLASS_IDS])\n    self.assertAllEqual([label_output_fn(0)],\n                        predictions[prediction_keys.PredictionKeys.CLASSES])\n    self.assertAllClose(\n        [0, 1], predictions[prediction_keys.PredictionKeys.ALL_CLASS_IDS])\n    self.assertAllEqual(\n        [label_output_fn(0), label_output_fn(1)],\n        predictions[prediction_keys.PredictionKeys.ALL_CLASSES])\n\n  def test_one_dim_without_label_vocabulary(self):\n    self._test_one_dim(\n        label_vocabulary=None, label_output_fn=lambda x: ('%s' % x).encode())\n\n  def test_one_dim_with_label_vocabulary(self):\n    n_classes = 2\n    self._test_one_dim(\n        label_vocabulary=['class_vocab_{}'.format(i) for i in range(n_classes)],\n        label_output_fn=lambda x: ('class_vocab_%s' % x).encode())\n\n  def _test_multi_dim_with_3_classes(self, 
label_vocabulary, label_output_fn):\n    \"\"\"Asserts predictions for multi-dimensional input and logits.\"\"\"\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ),\n                      global_step=0,\n                      model_dir=self._model_dir)\n\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=(2, 2),\n        feature_columns=(self._fc_impl.numeric_column('x', shape=(2,)),),\n        label_vocabulary=label_vocabulary,\n        n_classes=3,\n        model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        # Inputs shape is (batch_size, num_inputs).\n        x={'x': np.array([[10., 8.]])},\n        batch_size=1,\n        shuffle=False)\n    # Uses identical numbers as\n    # DNNModelFnTest.test_multi_dim_input_multi_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [-0.48, 0.48, 0.39] =>\n    # probabilities[i] = exp(logits[i]) / sum_j exp(logits[j]) =>\n    # probabilities = [0.16670536, 0.43538380, 0.39791084]\n    # class_ids = argmax(probabilities) = [1]\n    predictions = next(dnn_classifier.predict(input_fn=input_fn))\n    self.assertItemsEqual([\n        prediction_keys.PredictionKeys.LOGITS,\n        prediction_keys.PredictionKeys.PROBABILITIES,\n        prediction_keys.PredictionKeys.CLASS_IDS,\n        prediction_keys.PredictionKeys.CLASSES,\n        prediction_keys.PredictionKeys.ALL_CLASS_IDS,\n        prediction_keys.PredictionKeys.ALL_CLASSES\n    ], six.iterkeys(predictions))\n    self.assertAllClose([-0.48, 0.48, 0.39],\n                        predictions[prediction_keys.PredictionKeys.LOGITS])\n    self.assertAllClose(\n        [0.16670536, 0.43538380, 0.39791084],\n        predictions[prediction_keys.PredictionKeys.PROBABILITIES])\n    self.assertAllEqual([1],\n                        predictions[prediction_keys.PredictionKeys.CLASS_IDS])\n    
self.assertAllEqual([label_output_fn(1)],\n                        predictions[prediction_keys.PredictionKeys.CLASSES])\n    self.assertAllEqual(\n        [0, 1, 2], predictions[prediction_keys.PredictionKeys.ALL_CLASS_IDS])\n    self.assertAllEqual(\n        [label_output_fn(0),\n         label_output_fn(1),\n         label_output_fn(2)],\n        predictions[prediction_keys.PredictionKeys.ALL_CLASSES])\n\n  def test_multi_dim_with_3_classes_but_no_label_vocab(self):\n    self._test_multi_dim_with_3_classes(\n        label_vocabulary=None, label_output_fn=lambda x: ('%s' % x).encode())\n\n  def test_multi_dim_with_3_classes_and_label_vocab(self):\n    n_classes = 3\n    self._test_multi_dim_with_3_classes(\n        label_vocabulary=['class_vocab_{}'.format(i) for i in range(n_classes)],\n        label_output_fn=lambda x: ('class_vocab_%s' % x).encode())\n\n\nclass BaseDNNRegressorPredictTest(object):\n\n  def __init__(self, dnn_regressor_fn, fc_impl=feature_column_v2):\n    self._dnn_regressor_fn = dnn_regressor_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_one_dim(self):\n    \"\"\"Asserts predictions for one-dimensional input and logits.\"\"\"\n    # Create checkpoint: num_inputs=1, hidden_units=(2, 2), num_outputs=1.\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ),\n                      global_step=0,\n                      model_dir=self._model_dir)\n\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=(2, 2),\n        feature_columns=(self._fc_impl.numeric_column('x'),),\n        model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        x={'x': np.array([[10.]])}, batch_size=1, shuffle=False)\n    # Uses identical numbers as 
DNNModelTest.test_one_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [[-2.08]] => predictions = [-2.08].\n    self.assertAllClose({\n        prediction_keys.PredictionKeys.PREDICTIONS: [-2.08],\n    }, next(dnn_regressor.predict(input_fn=input_fn)))\n\n  def test_multi_dim(self):\n    \"\"\"Asserts predictions for multi-dimensional input and logits.\"\"\"\n    # Create checkpoint: num_inputs=2, hidden_units=(2, 2), num_outputs=3.\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), 100, self._model_dir)\n\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=(2, 2),\n        feature_columns=(self._fc_impl.numeric_column('x', shape=(2,)),),\n        label_dimension=3,\n        model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        # Inputs shape is (batch_size, num_inputs).\n        x={'x': np.array([[10., 8.]])},\n        batch_size=1,\n        shuffle=False)\n    # Uses identical numbers as\n    # DNNModelFnTest.test_multi_dim_input_multi_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [[-0.48, 0.48, 0.39]] => predictions = [-0.48, 0.48, 0.39]\n    self.assertAllClose(\n        {\n            prediction_keys.PredictionKeys.PREDICTIONS: [-0.48, 0.48, 0.39],\n        }, next(dnn_regressor.predict(input_fn=input_fn)))\n\n\nclass _SummaryHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Saves summaries every N steps.\"\"\"\n\n  def __init__(self):\n    self._summaries = []\n\n  def begin(self):\n    self._summary_op = tf.compat.v1.summary.merge_all()\n\n  def before_run(self, run_context):\n    return tf.compat.v1.train.SessionRunArgs({'summary': self._summary_op})\n\n  def after_run(self, run_context, run_values):\n    s = tf.compat.v1.summary.Summary()\n    s.ParseFromString(run_values.results['summary'])\n    self._summaries.append(s)\n\n  def 
summaries(self):\n    return tuple(self._summaries)\n\n\ndef _assert_checkpoint(testcase, global_step, input_units, hidden_units,\n                       output_units, model_dir):\n  \"\"\"Asserts checkpoint contains expected variables with proper shapes.\n\n  Args:\n    testcase: A TestCase instance.\n    global_step: Expected global step value.\n    input_units: The dimension of input layer.\n    hidden_units: Iterable of integer sizes for the hidden layers.\n    output_units: The dimension of output layer (logits).\n    model_dir: The model directory.\n  \"\"\"\n  shapes = {name: shape for (name, shape) in tf.train.list_variables(model_dir)}\n\n  # Global step.\n  testcase.assertEqual([], shapes[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n  testcase.assertEqual(\n      global_step,\n      tf.train.load_variable(model_dir, tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n  # Hidden layer weights.\n  prev_layer_units = input_units\n  for i in range(len(hidden_units)):\n    layer_units = hidden_units[i]\n    testcase.assertAllEqual((prev_layer_units, layer_units),\n                            shapes[HIDDEN_WEIGHTS_NAME_PATTERN % i])\n    testcase.assertAllEqual((layer_units,),\n                            shapes[HIDDEN_BIASES_NAME_PATTERN % i])\n    prev_layer_units = layer_units\n\n  # Output layer weights.\n  testcase.assertAllEqual((prev_layer_units, output_units),\n                          shapes[LOGITS_WEIGHTS_NAME])\n  testcase.assertAllEqual((output_units,), shapes[LOGITS_BIASES_NAME])\n\n\ndef _assert_simple_summary(testcase, expected_values, actual_summary):\n  \"\"\"Assert summary the specified simple values.\n\n  Args:\n    testcase: A TestCase instance.\n    expected_values: Dict of expected tags and simple values.\n    actual_summary: `summary_pb2.Summary`.\n  \"\"\"\n  testcase.assertAllClose(\n      expected_values, {\n          v.tag: v.simple_value\n          for v in actual_summary.value\n          if (v.tag in expected_values)\n      })\n\n\nclass 
BaseDNNClassifierTrainTest(object):\n\n  def __init__(self, dnn_classifier_fn, fc_impl=feature_column_v2):\n    self._dnn_classifier_fn = dnn_classifier_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_from_scratch_with_default_optimizer_binary(self):\n    hidden_units = (2, 2)\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        model_dir=self._model_dir)\n\n    # Train for a few steps, then validate final checkpoint.\n    num_steps = 5\n    dnn_classifier.train(\n        input_fn=lambda: ({\n            'age': [[10.]]\n        }, [[1]]), steps=num_steps)\n    _assert_checkpoint(\n        self,\n        num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=1,\n        model_dir=self._model_dir)\n\n  def test_from_scratch_with_default_optimizer_multi_class(self):\n    hidden_units = (2, 2)\n    n_classes = 3\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, then validate final checkpoint.\n    num_steps = 5\n    dnn_classifier.train(\n        input_fn=lambda: ({\n            'age': [[10.]]\n        }, [[2]]), steps=num_steps)\n    _assert_checkpoint(\n        self,\n        num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=n_classes,\n        model_dir=self._model_dir)\n\n  def test_from_scratch_validate_summary(self):\n    hidden_units = (2, 2)\n    opt = mock_optimizer(self, hidden_units=hidden_units)\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=hidden_units,\n        
feature_columns=(self._fc_impl.numeric_column('age'),),\n        optimizer=opt,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, then validate optimizer, summaries, and\n    # checkpoint.\n    num_steps = 5\n    summary_hook = _SummaryHook()\n    dnn_classifier.train(\n        input_fn=lambda: ({\n            'age': [[10.]]\n        }, [[1]]),\n        steps=num_steps,\n        hooks=(summary_hook,))\n    self.assertEqual(num_steps,\n                     dnn_classifier.get_variable_value(opt.iterations.name))\n    _assert_checkpoint(\n        self,\n        num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=1,\n        model_dir=self._model_dir)\n    summaries = summary_hook.summaries()\n    self.assertEqual(num_steps, len(summaries))\n    for summary in summaries:\n      summary_keys = [v.tag for v in summary.value]\n      self.assertIn(metric_keys.MetricKeys.LOSS, summary_keys)\n\n  def test_binary_classification(self):\n    base_global_step = 100\n    hidden_units = (2, 2)\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), base_global_step, self._model_dir)\n\n    # Uses identical numbers as DNNModelFnTest.test_one_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [-2.08] => probabilities = [0.889, 0.111]\n    # loss = -1. 
* log(0.111) = 2.19772100\n    expected_loss = 2.19772100\n    opt = mock_optimizer(\n        self, hidden_units=hidden_units, expected_loss=expected_loss)\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        optimizer=opt,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, then validate optimizer, summaries, and\n    # checkpoint.\n    num_steps = 5\n    summary_hook = _SummaryHook()\n    dnn_classifier.train(\n        input_fn=lambda: ({\n            'age': [[10.]]\n        }, [[1]]),\n        steps=num_steps,\n        hooks=(summary_hook,))\n    self.assertEqual(base_global_step + num_steps,\n                     dnn_classifier.get_variable_value(opt.iterations.name))\n    summaries = summary_hook.summaries()\n    self.assertEqual(num_steps, len(summaries))\n    for summary in summaries:\n      _assert_simple_summary(\n          self, {\n              'dnn/hiddenlayer_0/fraction_of_zero_values': 0.,\n              'dnn/hiddenlayer_1/fraction_of_zero_values': .5,\n              'dnn/logits/fraction_of_zero_values': 0.,\n              metric_keys.MetricKeys.LOSS: expected_loss,\n          }, summary)\n    _assert_checkpoint(\n        self,\n        base_global_step + num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=1,\n        model_dir=self._model_dir)\n\n  def test_binary_classification_float_labels(self):\n    base_global_step = 100\n    hidden_units = (2, 2)\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), base_global_step, self._model_dir)\n\n    # Uses identical numbers as DNNModelFnTest.test_one_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [-2.08] => probabilities = [0.889, 0.111]\n    # loss = -0.8 * log(0.111) -0.2 * log(0.889) = 1.7817210\n    expected_loss = 
1.7817210\n    opt = mock_optimizer(\n        self, hidden_units=hidden_units, expected_loss=expected_loss)\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        optimizer=opt,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, then validate optimizer, summaries, and\n    # checkpoint.\n    num_steps = 5\n    dnn_classifier.train(\n        input_fn=lambda: ({\n            'age': [[10.]]\n        }, [[0.8]]), steps=num_steps)\n    self.assertEqual(base_global_step + num_steps,\n                     dnn_classifier.get_variable_value(opt.iterations.name))\n\n  def test_multi_class(self):\n    n_classes = 3\n    base_global_step = 100\n    hidden_units = (2, 2)\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n\n    # Uses identical numbers as DNNModelFnTest.test_multi_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [-2.08, 2.08, 1.19] => probabilities = [0.0109, 0.7011, 0.2879]\n    # loss = -1. 
* log(0.7011) = 0.35505795\n    expected_loss = 0.35505795\n    opt = mock_optimizer(\n        self, hidden_units=hidden_units, expected_loss=expected_loss)\n    dnn_classifier = self._dnn_classifier_fn(\n        n_classes=n_classes,\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        optimizer=opt,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, then validate optimizer, summaries, and\n    # checkpoint.\n    num_steps = 5\n    summary_hook = _SummaryHook()\n    dnn_classifier.train(\n        input_fn=lambda: ({\n            'age': [[10.]]\n        }, [[1]]),\n        steps=num_steps,\n        hooks=(summary_hook,))\n    self.assertEqual(base_global_step + num_steps,\n                     dnn_classifier.get_variable_value(opt.iterations.name))\n    summaries = summary_hook.summaries()\n    self.assertEqual(num_steps, len(summaries))\n    for summary in summaries:\n      _assert_simple_summary(\n          self, {\n              'dnn/hiddenlayer_0/fraction_of_zero_values': 0.,\n              'dnn/hiddenlayer_1/fraction_of_zero_values': .5,\n              'dnn/logits/fraction_of_zero_values': 0.,\n              metric_keys.MetricKeys.LOSS: expected_loss,\n          }, summary)\n    _assert_checkpoint(\n        self,\n        base_global_step + num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=n_classes,\n        model_dir=self._model_dir)\n\n\nclass BaseDNNRegressorTrainTest(object):\n\n  def __init__(self, dnn_regressor_fn, fc_impl=feature_column_v2):\n    self._dnn_regressor_fn = dnn_regressor_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_from_scratch_with_default_optimizer(self):\n    hidden_units = (2, 2)\n    dnn_regressor = 
self._dnn_regressor_fn(\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        model_dir=self._model_dir)\n\n    # Train for a few steps, then validate final checkpoint.\n    num_steps = 5\n    dnn_regressor.train(\n        input_fn=lambda: ({\n            'age': ((1,),)\n        }, ((10,),)), steps=num_steps)\n    _assert_checkpoint(\n        self,\n        num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=1,\n        model_dir=self._model_dir)\n\n  def test_from_scratch(self):\n    hidden_units = (2, 2)\n    opt = mock_optimizer(self, hidden_units=hidden_units)\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        optimizer=opt,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, then validate optimizer, summaries, and\n    # checkpoint.\n    num_steps = 5\n    summary_hook = _SummaryHook()\n    dnn_regressor.train(\n        input_fn=lambda: ({\n            'age': ((1,),)\n        }, ((5.,),)),\n        steps=num_steps,\n        hooks=(summary_hook,))\n    self.assertEqual(num_steps,\n                     dnn_regressor.get_variable_value(opt.iterations.name))\n    _assert_checkpoint(\n        self,\n        num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=1,\n        model_dir=self._model_dir)\n    summaries = summary_hook.summaries()\n    self.assertEqual(num_steps, len(summaries))\n    for summary in summaries:\n      summary_keys = [v.tag for v in summary.value]\n      self.assertIn(metric_keys.MetricKeys.LOSS, summary_keys)\n\n  def test_one_dim(self):\n    \"\"\"Asserts train loss for one-dimensional input and logits.\"\"\"\n    base_global_step = 100\n    hidden_units = (2, 2)\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], 
[1.]], [.3]),\n    ), base_global_step, self._model_dir)\n\n    # Uses identical numbers as DNNModelFnTest.test_one_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [-2.08] => predictions = [-2.08]\n    # loss = (1 + 2.08)^2 = 9.4864\n    expected_loss = 9.4864\n    opt = mock_optimizer(\n        self, hidden_units=hidden_units, expected_loss=expected_loss)\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        optimizer=opt,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, then validate optimizer, summaries, and\n    # checkpoint.\n    num_steps = 5\n    summary_hook = _SummaryHook()\n    dnn_regressor.train(\n        input_fn=lambda: ({\n            'age': [[10.]]\n        }, [[1.]]),\n        steps=num_steps,\n        hooks=(summary_hook,))\n    self.assertEqual(base_global_step + num_steps,\n                     dnn_regressor.get_variable_value(opt.iterations.name))\n    summaries = summary_hook.summaries()\n    self.assertEqual(num_steps, len(summaries))\n    for summary in summaries:\n      _assert_simple_summary(\n          self, {\n              'dnn/hiddenlayer_0/fraction_of_zero_values': 0.,\n              'dnn/hiddenlayer_1/fraction_of_zero_values': 0.5,\n              'dnn/logits/fraction_of_zero_values': 0.,\n              metric_keys.MetricKeys.LOSS: expected_loss,\n          }, summary)\n    _assert_checkpoint(\n        self,\n        base_global_step + num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=1,\n        model_dir=self._model_dir)\n\n  def test_multi_dim(self):\n    \"\"\"Asserts train loss for multi-dimensional input and logits.\"\"\"\n    base_global_step = 100\n    hidden_units = (2, 2)\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, 
.0]),\n    ), base_global_step, self._model_dir)\n    input_dimension = 2\n    label_dimension = 3\n\n    # Uses identical numbers as\n    # DNNModelFnTest.test_multi_dim_input_multi_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [[-0.48, 0.48, 0.39]]\n    # loss = (1+0.48)^2 + (-1-0.48)^2 + (0.5-0.39)^2 = 4.3929\n    # expected_loss = loss / 3 (batch size)\n    expected_loss = 1.4643\n    opt = mock_optimizer(\n        self, hidden_units=hidden_units, expected_loss=expected_loss)\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=hidden_units,\n        feature_columns=[\n            self._fc_impl.numeric_column('age', shape=[input_dimension])\n        ],\n        label_dimension=label_dimension,\n        optimizer=opt,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, then validate optimizer, summaries, and\n    # checkpoint.\n    num_steps = 5\n    summary_hook = _SummaryHook()\n    dnn_regressor.train(\n        input_fn=lambda: ({\n            'age': [[10., 8.]]\n        }, [[1., -1., 0.5]]),\n        steps=num_steps,\n        hooks=(summary_hook,))\n    self.assertEqual(base_global_step + num_steps,\n                     dnn_regressor.get_variable_value(opt.iterations.name))\n    summaries = summary_hook.summaries()\n    self.assertEqual(num_steps, len(summaries))\n    for summary in summaries:\n      _assert_simple_summary(\n          self, {\n              'dnn/hiddenlayer_0/fraction_of_zero_values': 0.,\n              'dnn/hiddenlayer_1/fraction_of_zero_values': 0.5,\n              'dnn/logits/fraction_of_zero_values': 0.,\n              metric_keys.MetricKeys.LOSS: expected_loss,\n          }, summary)\n    _assert_checkpoint(\n        self,\n        base_global_step + num_steps,\n        input_units=input_dimension,\n        hidden_units=hidden_units,\n        output_units=label_dimension,\n        model_dir=self._model_dir)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/head.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Abstractions for the head(s) of a model.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport abc\nimport collections\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.feature_column import feature_column\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.ops import lookup_ops\nfrom tensorflow.python.ops import string_ops\nfrom tensorflow.python.util import function_utils\nfrom tensorflow_estimator.python.estimator import model_fn\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.export import export_output\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n_DEFAULT_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n\n# The above default is defined by TF Serving, but these next three are just\n# a local convention without any special meaning.\n_CLASSIFY_SERVING_KEY = 'classification'\n_REGRESS_SERVING_KEY = 'regression'\n_PREDICT_SERVING_KEY = 'predict'\n\n# A LossSpec contains\n# * a scalar `Tensor` representing reduced weighted training loss\n# * a `Tensor` representing 
the unreduced unweighted loss\n# * a `Tensor` representing the example weights\n# * possibly processed labels (e.g. vocabulary lookup, shape manipulation, etc)\nLossSpec = collections.namedtuple(\n    'LossSpec',\n    ['training_loss', 'unreduced_loss', 'weights', 'processed_labels'])\n\n\ndef _summary_key(head_name, val):\n  return '%s/%s' % (val, head_name) if head_name else val\n\n\ndef _create_eval_metrics_tuple(fn, kwargs):\n  \"\"\"Creates TPU eval metrics tuple.\n\n  Helper function to make eval_metric tuple (eval_metric_fn, fn_kwargs) used\n  by `TPUEstimator`. TPUEstimator requires that `eval_metric_fn` take\n  exclusively Tensor arguments. This helper can help create such a function from\n  a more generic function that can take both Tensor and non-Tensor arguments.\n\n  Args:\n    fn: A eval_metric_fn that takes both Tensor and non-Tensor arguments. This\n      function must return a dict of form\n        {'metric name': (metric_tensor, eval_op)}\n    kwargs: Dict of arguments for `fn`.\n\n  Returns:\n    `eval_metric` tuple that can be passed to a `model_fn._TPUEstimatorSpec`.\n  \"\"\"\n  tensor_kwargs = {}\n  nontensor_kwargs = {}\n  for k, v in six.iteritems(kwargs):\n    if tf.is_tensor(v):\n      tensor_kwargs[k] = v\n    else:\n      nontensor_kwargs[k] = v\n\n  def _fn(**tensors):\n    return fn(**dict(nontensor_kwargs, **tensors))\n\n  return (_fn, tensor_kwargs)\n\n\nclass _Head(object):\n  \"\"\"Interface for the head/top of a model.\n\n  Given logits (or output of a hidden layer), a Head knows how to compute\n  predictions, loss, train_op, metrics and export outputs. It is meant to:\n\n  1. Simplify writing model_fn and to make model_fn more configurable\n  2. Support wide range of machine learning models. 
Since most heads can work\n     with logits, they can support DNN, RNN, Wide, Wide&Deep,\n     Global objectives, Gradient boosted trees and many other types\n     of machine learning models.\n\n  Common usage:\n  Here is simplified model_fn to build a DNN regression model.\n    ```python\n    def _my_dnn_model_fn(features, labels, mode, params, config=None):\n      # Optionally your callers can pass head to model_fn as a param.\n      head = tf.contrib.estimator.regression_head(...)\n      inputs = tf.feature_column.input_layer(features, ...)\n      hidden_layer0 = tf.layers.dense(\n          inputs, units=1000, activation=tf.nn.relu)\n      hidden_layer1 = tf.layers.dense(\n          hidden_layer0, units=500, activation=tf.nn.relu)\n      logits = tf.layers.dense(\n          hidden_layer1, units=head.logits_dimension, activation=None)\n\n      return head.create_estimator_spec(\n          features=features,\n          labels=labels,\n          mode=mode,\n          logits=logits,\n          optimizer=optimizer)\n    ```\n\n  There are cases where computing and applying gradients can not be meaningfully\n  captured with optimizer or train_op_fn we support (for example, with sync\n  optimizer). In such case, you can take the responsibility on your own. Here is\n  a common use case,\n    ```python\n    estimator_spec = head.create_estimator_spec(\n        features=features,\n        labels=labels,\n        mode=mode,\n        logits=logits,\n        train_op_fn=lambda _: tf.no_op())\n    if mode == ModeKeys.TRAIN:\n      optimizer = ...\n      sync = tf.train.SyncReplicasOptimizer(opt=optimizer, ...)\n      update_op = sync.minimize(\n          estimator_spec.loss, global_step=tf.get_global_step())\n      hooks = [sync.make_session_run_hook(is_chief)]\n      ... 
update train_op and hooks in EstimatorSpec and return\n    ```\n  \"\"\"\n  __metaclass__ = abc.ABCMeta\n\n  @abc.abstractproperty\n  def name(self):\n    \"\"\"The name of this head.\n\n    Returns:\n      A string.\n    \"\"\"\n    raise NotImplementedError('Calling an abstract method.')\n\n  @abc.abstractproperty\n  def logits_dimension(self):\n    \"\"\"Size of the last dimension of the logits `Tensor`.\n\n    Typically, logits is of shape `[batch_size, logits_dimension]`.\n\n    Returns:\n      The expected size of the `logits` tensor.\n    \"\"\"\n    raise NotImplementedError('Calling an abstract method.')\n\n  @abc.abstractmethod\n  def create_loss(self, features, mode, logits, labels):\n    \"\"\"Returns a loss Tensor from provided logits.\n\n    This function is designed to be used by framework developers.  Almost all\n    users should use create_estimator_spec(), which calls this internally.\n    `mode` and `features` are most likely not used, but some Head\n    implementations may require them.\n\n    Args:\n      features: Input `dict` of `Tensor` objects.\n      mode: Estimator's `ModeKeys`.\n      logits: logits `Tensor` to be used for loss construction.\n      labels: Labels `Tensor`, or `dict` of same.\n\n    Returns:\n      A LossSpec that contains\n      * the scalar `Tensor` representing reduced weighted training loss\n      * the `Tensor` representing the unreduced unweighted loss\n      * the `Tensor` representing the example weights\n      * possibly processed labels (e.g. 
vocabulary lookup, shape manipulation,\n        etc.)\n\n      To be extendable in the future.\n    \"\"\"\n    raise NotImplementedError('Calling an abstract method.')\n\n  # TODO(b/65403806): By default, collect regularization_losses from\n  # GraphKeys.REGULARIZATION_LOSSES collection.\n  def create_estimator_spec(self,\n                            features,\n                            mode,\n                            logits,\n                            labels=None,\n                            optimizer=None,\n                            train_op_fn=None,\n                            regularization_losses=None):\n    \"\"\"Returns `EstimatorSpec` that a model_fn can return.\n\n    Please note that,\n    + All args must be passed via name.\n\n    Args:\n      features: Input `dict` of `Tensor` or `SparseTensor` objects.\n      mode: Estimator's `ModeKeys`.\n      logits: logits `Tensor` to be used by the head.\n      labels: Labels `Tensor`, or `dict` of same.\n      optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.\n        Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which\n        updates variables and increments `global_step`.\n      train_op_fn: Function that takes a scalar loss `Tensor` and returns an op\n        to optimize the model with the loss in TRAIN mode. Used if `optimizer`\n        is `None`. Exactly one of `train_op_fn` and `optimizer` must be set in\n        TRAIN mode. None is allowed in other modes. 
If you want to optimize loss\n        yourself you can pass `lambda _: tf.no_op()` and then use\n          EstimatorSpec.loss to compute and apply gradients.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses.\n\n    Returns:\n      `EstimatorSpec`.\n    \"\"\"\n    try:\n      tpu_estimator_spec = (\n          self._create_tpu_estimator_spec(features, mode, logits, labels,\n                                          optimizer, train_op_fn,\n                                          regularization_losses))\n      return tpu_estimator_spec.as_estimator_spec()\n    except NotImplementedError:\n      # Not all subclasses of _Head will have implemented\n      # _create_tpu_estimator_spec. If it is implemented, we can use it to\n      # create our `EstimatorSpec` here.\n      raise NotImplementedError(\n          'Subclasses of _Head must implement `create_estimator_spec()` or '\n          '_create_tpu_estimator_spec().')\n\n  def _create_tpu_estimator_spec(self,\n                                 features,\n                                 mode,\n                                 logits,\n                                 labels=None,\n                                 optimizer=None,\n                                 train_op_fn=None,\n                                 regularization_losses=None):\n    \"\"\"Returns `model_fn._TPUEstimatorSpec` that a model_fn can return.\n\n    Args:\n      features: Input `dict` of `Tensor` or `SparseTensor` objects.\n      mode: Estimator's `ModeKeys`.\n      logits: logits `Tensor` to be used by the head.\n      labels: Labels `Tensor`, or `dict` of same.\n      optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.\n        Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which\n        updates variables and increments `global_step`.\n      train_op_fn: Function that takes a scalar loss `Tensor` and returns an op\n        
to optimize the model with the loss in TRAIN mode. Used if `optimizer`\n        is `None`. Exactly one of `train_op_fn` and `optimizer` must be set in\n        TRAIN mode. None is allowed in other modes. If you want to optimize loss\n        yourself you can pass `lambda _: tf.no_op()` and then use\n          EstimatorSpec.loss to compute and apply gradients.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses.\n\n    Returns:\n      A `model_fn._TPUEstimatorSpec' instance.\n    \"\"\"\n    raise NotImplementedError(\n        'TPUEstimatorSpec not available for this model head.')\n\n\ndef _check_dense_labels_match_logits_and_reshape(labels, logits,\n                                                 expected_labels_dimension):\n  \"\"\"Checks that labels shape matches logits and reshapes if needed.\n\n  Consider logits of shape [D0, D1, ... DN, logits_dimension]. Then labels\n  shape must be [D0, D1, ... DN, expected_labels_dimension].\n  If expected_labels_dimension=1, labels could be [D0, D1, ... DN] and this\n  method reshapes them to [D0, D1, ... DN, 1].\n\n  Args:\n    labels: labels Tensor.\n    logits: logits Tensor.\n    expected_labels_dimension: Integer.\n\n  Returns:\n    Validated and reshaped labels Tensor.\n  Raises:\n    ValueError: If labels is a SparseTensor.\n    ValueError: If labels shape is statically defined and fails validation.\n    OpError: If labels shape is not statically defined and fails validation.\n  \"\"\"\n  if labels is None:\n    raise ValueError(\n        'You must provide a labels Tensor. Given: None. '\n        'Suggested troubleshooting steps: Check that your data contain '\n        'your label feature. 
Check that your input_fn properly parses and '\n        'returns labels.')\n  with ops.name_scope(None, 'labels', (labels, logits)) as scope:\n    labels = tf.compat.v1.convert_to_tensor_or_sparse_tensor(labels)\n    if isinstance(labels, tf.sparse.SparseTensor):\n      raise ValueError(\n          'SparseTensor labels are not supported. '\n          'labels must be a Tensor of shape [D0, D1, ..., DN, %s], '\n          'e.g. [batch_size, %s]. '\n          'Suggested Fix (1): Check the label feature in your data. '\n          'Each example must contain %s value(s). If not, your choice of label '\n          'was probably incorrect. '\n          'Suggested Fix (2): In your input_fn, use '\n          'tf.sparse_tensor_to_dense() to turn labels into a Tensor.'\n          '' % (expected_labels_dimension, expected_labels_dimension,\n                expected_labels_dimension))\n    if (labels.shape.ndims is not None and logits.shape.ndims is not None and\n        labels.shape.ndims == logits.shape.ndims - 1):\n      labels = tf.compat.v1.expand_dims(labels, -1)\n    labels_shape = tf.compat.v1.shape(labels)\n    logits_shape = tf.compat.v1.shape(logits)\n    err_msg = (\n        'labels shape must be [D0, D1, ... DN, {}]. '\n        'Suggested Fix: check your n_classes argument to the estimator '\n        'and/or the shape of your label.'.format(expected_labels_dimension))\n    assert_rank = tf.compat.v1.debugging.assert_rank_at_least(\n        labels, 2, message=err_msg)\n    with tf.control_dependencies([assert_rank]):\n      static_shape = labels.shape\n      if static_shape.ndims is not None:\n        dim1 = static_shape[-1]\n        if (dim1 is not None) and (dim1 != expected_labels_dimension):\n          raise ValueError('Mismatched label shape. '\n                           'Expected labels dimension=%s.  Received %s. 
'\n                           'Suggested Fix:'\n                           'If your classifier expects one-hot encoding label,'\n                           'check your n_classes argument to the estimator '\n                           'and/or the shape of your label. '\n                           'Otherwise, check the shape of your label.' %\n                           (expected_labels_dimension, dim1))\n      expected_labels_shape = tf.concat(\n          [logits_shape[:-1], [expected_labels_dimension]], axis=0)\n      assert_dimension = tf.compat.v1.debugging.assert_equal(\n          expected_labels_shape,\n          labels_shape,\n          message=err_msg,\n          data=[\n              'expected_labels_shape: ', expected_labels_shape,\n              'labels_shape: ', labels_shape\n          ])\n      with tf.control_dependencies([assert_dimension]):\n        return tf.identity(labels, name=scope)\n\n\ndef _get_weights_and_check_match_logits(features,\n                                        weight_column,\n                                        logits,\n                                        allow_per_logit_weights=False):\n  \"\"\"Fetches weights from features and checks that the shape matches logits.\n\n  Consider logits of shape [D0, D1, ... DN, logits_dimension]. Weights shape\n  can be either:\n  * [D0, D1, ... DN, logits_dimension] if `allow_per_logit_weights=True`.\n  * [D0, D1, ... DN, 1]\n  * [D0, D1, ... DN]: In this case, weights is reshaped into\n    [D0, D1, ... DN, 1] to work with weight broadcasting rules.\n\n  Args:\n    features: The features dict that contains weights.\n    weight_column: The weight column. If not given, this method returns 1.\n    logits: logits Tensor.\n    allow_per_logit_weights: Boolean. Whether we allow weights along the logits\n      dimension, namely shape `[D0, D1, ... 
DN, logits_dimension]`.\n\n  Returns:\n    Validated and reshaped weights Tensor.\n  Raises:\n    ValueError: If the weights `Tensor` cannot be cast into float.\n  \"\"\"\n  if allow_per_logit_weights:\n    err_msg = ('weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or '\n               '[D0, D1, ... DN, logits_dimension]')\n  else:\n    err_msg = ('weights shape must be [D0, D1, ... DN] or [D0, D1, ... DN, 1]')\n  with ops.name_scope(\n      None, 'weights',\n      values=tuple(six.itervalues(features)) + (logits,)) as scope:\n    # Fetch the weights.\n    if weight_column is None:\n      return 1.\n    if isinstance(weight_column, six.string_types):\n      weight_column = tf.feature_column.numeric_column(\n          key=weight_column, shape=(1,))\n    if not isinstance(\n        weight_column,\n        (tf.compat.v2.__internal__.feature_column.DenseColumn, feature_column._DenseColumn)):  # pylint: disable=protected-access\n      raise TypeError('Weight column must be either a string or _DenseColumn.'\n                      ' Given type: {}.'.format(type(weight_column)))\n    weights = weight_column._get_dense_tensor(  # pylint: disable=protected-access\n        feature_column._LazyBuilder(features))  # pylint: disable=protected-access\n    if not (weights.dtype.is_floating or weights.dtype.is_integer):\n      raise ValueError('Weight column should be castable to float. 
'\n                       'Given dtype: {}'.format(weights.dtype))\n    weights = tf.cast(weights, name='weights', dtype=tf.dtypes.float32)\n\n    # Validate the weights shape.\n    weights_shape = tf.compat.v1.shape(weights, name='weights_shape')\n    logits_shape = tf.compat.v1.shape(logits, name='logits_shape')\n    if (weights.shape.ndims is not None and logits.shape.ndims is not None and\n        weights.shape.ndims == logits.shape.ndims - 1):\n      assert_dimension = tf.compat.v1.debugging.assert_equal(\n          logits_shape[:-1],\n          weights_shape,\n          message=err_msg,\n          data=[\n              'logits_shape: ', logits_shape, 'weights_shape: ', weights_shape\n          ])\n      with tf.control_dependencies([assert_dimension]):\n        return tf.compat.v1.expand_dims(weights, -1, name=scope)\n    supported_weights_shape = tf.concat([logits_shape[:-1], [1]], axis=0)\n    if allow_per_logit_weights:\n      condition = tf.math.reduce_any([\n          tf.reduce_all(tf.math.equal(logits_shape, weights_shape)),\n          tf.reduce_all(tf.math.equal(supported_weights_shape, weights_shape))\n      ])\n      assert_dimension = tf.debugging.Assert(\n          condition=condition,\n          data=[\n              err_msg, 'logits_shape: ', logits_shape, 'weights_shape: ',\n              weights_shape\n          ])\n    else:\n      assert_dimension = tf.compat.v1.debugging.assert_equal(\n          supported_weights_shape,\n          weights_shape,\n          message=err_msg,\n          data=[\n              'logits_shape: ', logits_shape, 'weights_shape: ', weights_shape\n          ])\n    with tf.control_dependencies([assert_dimension]):\n      return tf.identity(weights, name=scope)\n\n\ndef _check_logits_final_dim(logits, expected_logits_dimension):\n  \"\"\"Checks that logits shape is [D0, D1, ... 
DN, logits_dimension].\"\"\"\n  with ops.name_scope(None, 'logits', (logits,)) as scope:\n    logits = tf.cast(logits, dtype=tf.dtypes.float32)\n    logits_shape = tf.compat.v1.shape(logits)\n    assert_rank = tf.compat.v1.debugging.assert_rank_at_least(\n        logits,\n        2,\n        data=[logits_shape],\n        message='logits shape must be [D0, D1, ... DN, logits_dimension]')\n    with tf.control_dependencies([assert_rank]):\n      static_shape = logits.shape\n      if static_shape.ndims is not None and static_shape[-1] is not None:\n        if (isinstance(expected_logits_dimension, int) and\n            static_shape[-1] != expected_logits_dimension):\n          raise ValueError(\n              'logits shape must be [D0, D1, ... DN, logits_dimension=%s], '\n              'got %s.' % (expected_logits_dimension, static_shape))\n        return logits\n      assert_dimension = tf.compat.v1.debugging.assert_equal(\n          expected_logits_dimension,\n          logits_shape[-1],\n          data=[logits_shape],\n          message=('logits shape must be [D0, D1, ... DN, '\n                   'logits_dimension=%s]' % (expected_logits_dimension,)))\n      with tf.control_dependencies([assert_dimension]):\n        return tf.identity(logits, name=scope)\n\n\ndef _validate_loss_fn_args(loss_fn):\n  \"\"\"Validates loss_fn arguments.\n\n  Required arguments: labels, logits.\n  Optional arguments: features.\n\n  Args:\n    loss_fn: The loss function.\n\n  Raises:\n    ValueError: If the signature is unexpected.\n  \"\"\"\n  loss_fn_args = function_utils.fn_args(loss_fn)\n  for required_arg in ['labels', 'logits']:\n    if required_arg not in loss_fn_args:\n      raise ValueError('loss_fn must contain argument: {}. 
'\n                       'Given arguments: {}'.format(required_arg, loss_fn_args))\n  invalid_args = list(set(loss_fn_args) - set(['labels', 'logits', 'features']))\n  if invalid_args:\n    raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args))\n\n\ndef _validate_n_classes(n_classes):\n  \"\"\"Validates n_classes argument.\n\n  Required arguments: n_classes.\n\n  Args:\n    n_classes: The number of classes.\n\n  Raises:\n    ValueError: If n_classes is <= 2 and n_classes is a Python integer.\n  Returns:\n    n_classes in its original type.\n  \"\"\"\n  if isinstance(n_classes, int) and (n_classes <= 2):\n    raise ValueError('n_classes must be > 2: %s.' % n_classes)\n\n  n_classes_as_tensor = ops.convert_to_tensor(n_classes)\n  assert_n_classes = tf.compat.v1.debugging.assert_greater(\n      n_classes_as_tensor, 2, message='n_classes must be greater than 2')\n  with tf.control_dependencies([assert_n_classes]):\n    tf.no_op()\n  # Return n_classes in its original type, so that any code\n  # using the accessor logits_dimension() has the original type.\n  return n_classes\n\n\ndef _call_loss_fn(loss_fn, labels, logits, features, expected_loss_dim=1):\n  \"\"\"Calls loss_fn and checks the returned shape.\n\n  Args:\n    loss_fn: The loss function.\n    labels: Processed labels Tensor.\n    logits: Logits Tensor of shape [D0, D1, ... DN, logits_dimension].\n    features: Features dict.\n    expected_loss_dim: The expected last dimension of loss Tensor.\n\n  Returns:\n    Loss Tensor with shape [D0, D1, ... 
DN, expected_loss_dim].\n  \"\"\"\n  loss_fn_args = function_utils.fn_args(loss_fn)\n  kwargs = {}\n  if 'features' in loss_fn_args:\n    kwargs['features'] = features\n  with ops.name_scope(\n      None,\n      'call_loss_fn',\n      values=[labels, logits] + list(six.itervalues(features))):\n    unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs)\n    logits_shape = tf.compat.v1.shape(logits, name='logits_shape')\n    expected_loss_shape = tf.concat([logits_shape[:-1], [expected_loss_dim]],\n                                    axis=0,\n                                    name='expected_loss_shape')\n    loss_shape = tf.compat.v1.shape(unweighted_loss, name='loss_shape')\n    check_loss_shape_op = tf.debugging.Assert(\n        tf.reduce_all(tf.math.equal(loss_shape, expected_loss_shape)),\n        data=[\n            'loss_fn must return Tensor of shape '\n            '[D0, D1, ... DN, {}]. '.format(expected_loss_dim),\n            'logits_shape: ', logits_shape, 'loss_shape: ', loss_shape\n        ],\n        name='check_loss_shape')\n    with tf.control_dependencies([check_loss_shape_op]):\n      return tf.identity(unweighted_loss)\n\n\ndef _indicator_labels_mean(labels, weights=None, name=None):\n  with ops.name_scope(name, 'labels_mean', (labels, weights)) as scope:\n    labels = tf.cast(labels, name='labels', dtype=tf.dtypes.float32)\n    if weights is not None:\n      weights = tf.compat.v2.__internal__.ops.broadcast_weights(weights, labels)\n    return tf.compat.v1.metrics.mean(labels, weights=weights, name=scope)\n\n\ndef _all_class_ids(logits, n_classes):\n  batch_size = tf.compat.v1.shape(logits)[0]\n  class_id_list = tf.range(n_classes)\n  return tf.tile(\n      input=tf.compat.v1.expand_dims(input=class_id_list, axis=0),\n      multiples=[batch_size, 1])\n\n\ndef _all_classes(logits, n_classes, label_vocabulary=None):\n  batch_size = tf.compat.v1.shape(logits)[0]\n  if label_vocabulary:\n    classes_list = label_vocabulary\n  else:\n    
classes_list = string_ops.as_string(tf.range(n_classes))\n  return tf.tile(\n      input=tf.compat.v1.expand_dims(input=classes_list, axis=0),\n      multiples=[batch_size, 1])\n\n\ndef _classification_output(scores, n_classes, label_vocabulary=None):\n  batch_size = tf.compat.v1.shape(scores)[0]\n  if label_vocabulary:\n    export_class_list = label_vocabulary\n  else:\n    export_class_list = string_ops.as_string(tf.range(n_classes))\n  export_output_classes = tf.tile(\n      input=tf.compat.v1.expand_dims(input=export_class_list, axis=0),\n      multiples=[batch_size, 1])\n  return export_output.ClassificationOutput(\n      scores=scores,\n      # `ClassificationOutput` requires string classes.\n      classes=export_output_classes)\n\n\ndef _accuracy_baseline(labels_mean):\n  \"\"\"Return accuracy baseline based on labels mean.\n\n  This is the best the model could do by always predicting one class.\n\n  Args:\n    labels_mean: Tuple of value and update op.\n\n  Returns:\n    Tuple of value and update op.\n  \"\"\"\n  with ops.name_scope(None, 'accuracy_baseline', labels_mean):\n    value, update_op = labels_mean\n    return (tf.math.maximum(value, 1. 
- value, name='value'),\n            tf.math.maximum(update_op, 1 - update_op, name='update_op'))\n\n\ndef _predictions_mean(predictions, weights=None, name=None):\n  with ops.name_scope(name, 'predictions_mean',\n                      (predictions, weights)) as scope:\n    predictions = tf.cast(\n        predictions, name='predictions', dtype=tf.dtypes.float32)\n    if weights is not None:\n      weights = tf.compat.v2.__internal__.ops.broadcast_weights(weights, predictions)\n    return tf.compat.v1.metrics.mean(predictions, weights=weights, name=scope)\n\n\ndef _auc(labels, predictions, weights=None, curve='ROC', name=None):\n  with ops.name_scope(name, 'auc', (predictions, labels, weights)) as scope:\n    predictions = tf.cast(\n        predictions, name='predictions', dtype=tf.dtypes.float32)\n    if weights is not None:\n      weights = tf.compat.v2.__internal__.ops.broadcast_weights(weights, predictions)\n    return tf.compat.v1.metrics.auc(\n        labels=labels,\n        predictions=predictions,\n        weights=weights,\n        curve=curve,\n        name=scope)\n\n\ndef _accuracy_at_threshold(labels, predictions, weights, threshold, name=None):\n  with ops.name_scope(name, 'accuracy_at_%s' % threshold,\n                      (predictions, labels, weights, threshold)) as scope:\n    threshold_predictions = tf.compat.v1.to_float(\n        tf.math.greater_equal(predictions, threshold))\n    return tf.compat.v1.metrics.accuracy(\n        labels=labels,\n        predictions=threshold_predictions,\n        weights=weights,\n        name=scope)\n\n\ndef _precision_at_threshold(labels, predictions, weights, threshold, name=None):\n  with ops.name_scope(name, 'precision_at_%s' % threshold,\n                      (predictions, labels, weights, threshold)) as scope:\n    precision_tensor, update_op = tf.compat.v1.metrics.precision_at_thresholds(\n        labels=labels,\n        predictions=predictions,\n        thresholds=(threshold,),\n        weights=weights,\n   
     name=scope)\n    return tf.compat.v1.squeeze(precision_tensor), tf.compat.v1.squeeze(\n        update_op)\n\n\ndef _recall_at_threshold(labels, predictions, weights, threshold, name=None):\n  with ops.name_scope(name, 'recall_at_%s' % threshold,\n                      (predictions, labels, weights, threshold)) as scope:\n    precision_tensor, update_op = tf.compat.v1.metrics.recall_at_thresholds(\n        labels=labels,\n        predictions=predictions,\n        thresholds=(threshold,),\n        weights=weights,\n        name=scope)\n    return tf.compat.v1.squeeze(precision_tensor), tf.compat.v1.squeeze(\n        update_op)\n\n\ndef _multi_class_head_with_softmax_cross_entropy_loss(\n    n_classes,\n    weight_column=None,\n    label_vocabulary=None,\n    loss_reduction=tf.compat.v1.losses.Reduction.SUM,\n    loss_fn=None,\n    name=None):\n  \"\"\"Creates a '_Head' for multi class classification.\n\n  The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`.\n  In many applications, the shape is `[batch_size, n_classes]`.\n\n  `labels` must be a dense `Tensor` with shape matching `logits`, namely\n  `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string\n  `Tensor` with values from the vocabulary. If `label_vocabulary` is not given,\n  `labels` must be an integer `Tensor` with values specifying the class index.\n\n  If `weight_column` is specified, weights must be of shape\n  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.\n\n  The loss is the weighted sum over the input dimensions. Namely, if the input\n  labels have shape `[batch_size, 1]`, the loss is the weighted sum over\n  `batch_size`.\n\n  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or\n  `(labels, logits, features)` as arguments and returns unreduced loss with\n  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support integer `labels` with\n  shape `[D0, D1, ... DN, 1]`. 
Namely, the head applies `label_vocabulary` to\n  the input labels before passing them to `loss_fn`.\n\n  Args:\n    n_classes: Number of classes, must be greater than 2 (for 2 classes, use\n      `_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`).\n    weight_column: A string or a `_NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It is used to down weight or boost examples during training. It\n      will be multiplied by the loss of the example.\n    label_vocabulary: A list or tuple of strings representing possible label\n      values. If it is not given, that means labels are already encoded as an\n      integer within [0, n_classes). If given, labels must be of string type and\n      have any value in `label_vocabulary`. Note that errors will be raised if\n      `label_vocabulary` is not provided but labels are strings.\n    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to\n      reduce training loss over batch. Defaults to `SUM`.\n    loss_fn: Optional loss function.\n    name: name of the head. If provided, summary and metrics keys will be\n      suffixed by `\"/\" + name`. Also used as `name_scope` when creating ops.\n\n  Returns:\n    An instance of `_Head` for multi class classification.\n\n  Raises:\n    ValueError: If `n_classes`, `label_vocabulary` or `loss_reduction` is\n      invalid.\n  \"\"\"\n  if label_vocabulary is not None and not isinstance(label_vocabulary,\n                                                     (list, tuple)):\n    raise ValueError(\n        'label_vocabulary should be a list or a tuple. 
Given type: {}'.format(\n            type(label_vocabulary)))\n  if (loss_reduction not in tf.compat.v1.losses.Reduction.all() or\n      loss_reduction == tf.compat.v1.losses.Reduction.NONE):\n    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))\n  if loss_fn:\n    _validate_loss_fn_args(loss_fn)\n  return _MultiClassHeadWithSoftmaxCrossEntropyLoss(\n      n_classes=n_classes,\n      weight_column=weight_column,\n      label_vocabulary=label_vocabulary,\n      loss_reduction=loss_reduction,\n      loss_fn=loss_fn,\n      name=name)\n\n\nclass _MultiClassHeadWithSoftmaxCrossEntropyLoss(_Head):\n  \"\"\"See `_multi_class_head_with_softmax_cross_entropy_loss`.\"\"\"\n\n  def __init__(self,\n               n_classes,\n               weight_column=None,\n               label_vocabulary=None,\n               loss_reduction=tf.compat.v1.losses.Reduction.SUM,\n               loss_fn=None,\n               name=None):\n    if n_classes is None:\n      raise ValueError('n_classes cannot be None')\n    self._n_classes = _validate_n_classes(n_classes)\n    self._weight_column = weight_column\n    self._label_vocabulary = label_vocabulary\n    self._loss_reduction = loss_reduction\n    self._loss_fn = loss_fn\n    self._name = name\n\n  @property\n  def name(self):\n    return self._name\n\n  @property\n  def logits_dimension(self):\n    return self._n_classes\n\n  def _eval_metric_ops(self, labels, class_ids, weights, unreduced_loss,\n                       regularization_loss):\n    \"\"\"Returns the Eval metric ops.\"\"\"\n    with ops.name_scope(\n        None, 'metrics',\n        (labels, class_ids, weights, unreduced_loss, regularization_loss)):\n      keys = metric_keys.MetricKeys\n      metric_ops = {\n          # Estimator already adds a metric for loss.\n          # TODO(xiejw): Any other metrics?\n          _summary_key(self._name, keys.LOSS_MEAN):\n              tf.compat.v1.metrics.mean(\n                  values=unreduced_loss, weights=weights, 
name=keys.LOSS_MEAN),\n          _summary_key(self._name, keys.ACCURACY):\n              tf.compat.v1.metrics.accuracy(\n                  labels=labels,\n                  predictions=class_ids,\n                  weights=weights,\n                  name=keys.ACCURACY),\n      }\n      if regularization_loss is not None:\n        metric_ops[_summary_key(self._name, keys.LOSS_REGULARIZATION)] = (\n            tf.compat.v1.metrics.mean(\n                values=regularization_loss, name=keys.LOSS_REGULARIZATION))\n    return metric_ops\n\n  def _label_ids(self, labels):\n    \"\"\"Converts labels to integer id space.\"\"\"\n    if self._label_vocabulary is None:\n      if not labels.dtype.is_integer:\n        raise ValueError(\n            'Labels dtype should be integer. Instead got {}.'.format(\n                labels.dtype))\n      label_ids = labels\n    else:\n      if labels.dtype != tf.dtypes.string:\n        raise ValueError('Labels dtype should be string if there is a '\n                         'vocabulary. 
Instead got {}'.format(labels.dtype))\n      label_ids = lookup_ops.index_table_from_tensor(\n          vocabulary_list=tuple(self._label_vocabulary),\n          name='class_id_lookup').lookup(labels)\n    return _assert_range(label_ids, self._n_classes)\n\n  def create_loss(self, features, mode, logits, labels):\n    \"\"\"See `Head`.\"\"\"\n    del mode  # Unused for this head.\n    logits = ops.convert_to_tensor(logits)\n    labels = _check_dense_labels_match_logits_and_reshape(\n        labels=labels, logits=logits, expected_labels_dimension=1)\n    label_ids = self._label_ids(labels)\n    if self._loss_fn:\n      unweighted_loss = _call_loss_fn(\n          loss_fn=self._loss_fn,\n          labels=label_ids,\n          logits=logits,\n          features=features,\n          expected_loss_dim=1)\n    else:\n      unweighted_loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(\n          labels=label_ids,\n          logits=logits,\n          reduction=tf.compat.v1.losses.Reduction.NONE)\n      # Restore the squeezed dim, so unweighted_loss matches the weights shape.\n      unweighted_loss = tf.compat.v1.expand_dims(unweighted_loss, axis=-1)\n    weights = _get_weights_and_check_match_logits(\n        features=features, weight_column=self._weight_column, logits=logits)\n    training_loss = tf.compat.v1.losses.compute_weighted_loss(\n        unweighted_loss, weights=weights, reduction=self._loss_reduction)\n    return LossSpec(\n        training_loss=training_loss,\n        unreduced_loss=unweighted_loss,\n        weights=weights,\n        processed_labels=label_ids)\n\n  def _create_tpu_estimator_spec(self,\n                                 features,\n                                 mode,\n                                 logits,\n                                 labels=None,\n                                 optimizer=None,\n                                 train_op_fn=None,\n                                 regularization_losses=None):\n    \"\"\"Returns a 
`model_fn._TPUEstimatorSpec`.\n\n    Args:\n      features: Input `dict` of `Tensor` or `SparseTensor` objects.\n      mode: Estimator's `ModeKeys`.\n      logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.\n        For many applications, the shape is `[batch_size, logits_dimension]`.\n      labels: Labels integer or string `Tensor` with shape matching `logits`,\n        namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is required\n        argument when `mode` equals `TRAIN` or `EVAL`.\n      optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.\n        Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which\n        updates variables and increments `global_step`.\n      train_op_fn: Function that takes a scalar loss `Tensor` and returns\n        `train_op`. Used if `optimizer` is `None`.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses. These losses are\n        usually expressed as a batch average, so for best results users need to\n        set `loss_reduction=SUM_OVER_BATCH_SIZE` when creating the head to avoid\n        scaling errors.\n\n    Returns:\n      A `model_fn._TPUEstimatorSpec` instance.\n    Raises:\n      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN\n        mode, or if both are set.\n    \"\"\"\n    with ops.name_scope(self._name, 'head'):\n      logits = _check_logits_final_dim(logits, self.logits_dimension)\n\n      # Predict.\n      pred_keys = prediction_keys.PredictionKeys\n      with ops.name_scope(None, 'predictions', (logits,)):\n        all_class_ids = _all_class_ids(logits, self._n_classes)\n        all_classes = _all_classes(\n            logits, self._n_classes, label_vocabulary=self._label_vocabulary)\n        # class_ids's shape is [D0, D1, ... 
DN].\n        class_ids = tf.compat.v1.math.argmax(\n            logits, axis=-1, name=pred_keys.CLASS_IDS)\n        class_ids = tf.compat.v1.expand_dims(class_ids, axis=-1)\n        if self._label_vocabulary:\n          table = lookup_ops.index_to_string_table_from_tensor(\n              vocabulary_list=self._label_vocabulary,\n              name='class_string_lookup')\n          classes = table.lookup(class_ids)\n        else:\n          classes = tf.strings.as_string(class_ids, name='str_classes')\n\n        probabilities = tf.compat.v1.nn.softmax(\n            logits, name=pred_keys.PROBABILITIES)\n        predictions = {\n            pred_keys.LOGITS: logits,\n            pred_keys.PROBABILITIES: probabilities,\n            # Expand to [batch_size, 1]\n            pred_keys.CLASS_IDS: class_ids,\n            pred_keys.CLASSES: classes,\n            pred_keys.ALL_CLASS_IDS: all_class_ids,\n            pred_keys.ALL_CLASSES: all_classes,\n        }\n      if mode == ModeKeys.PREDICT:\n        classifier_output = _classification_output(\n            scores=probabilities,\n            n_classes=self._n_classes,\n            label_vocabulary=self._label_vocabulary)\n        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n            mode=ModeKeys.PREDICT,\n            predictions=predictions,\n            export_outputs={\n                _DEFAULT_SERVING_KEY: classifier_output,\n                _CLASSIFY_SERVING_KEY: classifier_output,\n                _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)\n            })\n\n      training_loss, unreduced_loss, weights, label_ids = self.create_loss(\n          features=features, mode=mode, logits=logits, labels=labels)\n      if regularization_losses:\n        regularization_loss = tf.math.add_n(regularization_losses)\n        regularized_training_loss = tf.math.add_n(\n            [training_loss, regularization_loss])\n      else:\n        regularization_loss = None\n        
regularized_training_loss = training_loss\n      # Eval.\n      if mode == ModeKeys.EVAL:\n        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n            mode=ModeKeys.EVAL,\n            predictions=predictions,\n            loss=regularized_training_loss,\n            eval_metrics=_create_eval_metrics_tuple(\n                self._eval_metric_ops, {\n                    'labels': label_ids,\n                    'class_ids': class_ids,\n                    'weights': weights,\n                    'unreduced_loss': unreduced_loss,\n                    'regularization_loss': regularization_loss\n                }))\n\n      # Train.\n      if optimizer is not None:\n        if train_op_fn is not None:\n          raise ValueError('train_op_fn and optimizer cannot both be set.')\n        train_op = optimizer.minimize(\n            regularized_training_loss,\n            global_step=tf.compat.v1.train.get_global_step())\n      elif train_op_fn is not None:\n        train_op = train_op_fn(regularized_training_loss)\n      else:\n        raise ValueError('train_op_fn and optimizer cannot both be None.')\n      train_op = _append_update_ops(train_op)\n      # Only summarize mean_loss for SUM reduction to preserve backwards\n      # compatibility. 
Otherwise skip it to avoid unnecessary computation.\n      if self._loss_reduction == tf.compat.v1.losses.Reduction.SUM:\n        example_weight_sum = tf.math.reduce_sum(\n            weights * tf.compat.v1.ones_like(unreduced_loss))\n        mean_loss = training_loss / example_weight_sum\n      else:\n        mean_loss = None\n    with ops.name_scope(''):\n      keys = metric_keys.MetricKeys\n      tf.compat.v1.summary.scalar(\n          _summary_key(self._name, keys.LOSS), regularized_training_loss)\n      if mean_loss is not None:\n        tf.compat.v1.summary.scalar(\n            _summary_key(self._name, keys.LOSS_MEAN), mean_loss)\n      if regularization_loss is not None:\n        tf.compat.v1.summary.scalar(\n            _summary_key(self._name, keys.LOSS_REGULARIZATION),\n            regularization_loss)\n    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n        mode=ModeKeys.TRAIN,\n        predictions=predictions,\n        loss=regularized_training_loss,\n        train_op=train_op)\n\n\ndef _binary_logistic_head_with_sigmoid_cross_entropy_loss(\n    weight_column=None,\n    thresholds=None,\n    label_vocabulary=None,\n    loss_reduction=tf.compat.v1.losses.Reduction.SUM,\n    loss_fn=None,\n    name=None):\n  \"\"\"Creates a `_Head` for single label binary classification.\n\n  This head uses `sigmoid_cross_entropy_with_logits` loss.\n\n  The head expects `logits` with shape `[D0, D1, ... DN, 1]`.\n  In many applications, the shape is `[batch_size, 1]`.\n\n  `labels` must be a dense `Tensor` with shape matching `logits`, namely\n  `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string\n  `Tensor` with values from the vocabulary. If `label_vocabulary` is not given,\n  `labels` must be float `Tensor` with values in the interval `[0, 1]`.\n\n  If `weight_column` is specified, weights must be of shape\n  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.\n\n  The loss is the weighted sum over the input dimensions. 
Namely, if the input\n  labels have shape `[batch_size, 1]`, the loss is the weighted sum over\n  `batch_size`.\n\n  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or\n  `(labels, logits, features)` as arguments and returns unreduced loss with\n  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support float `labels` with\n  shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to\n  the input labels before passing them to `loss_fn`.\n\n  Args:\n    weight_column: A string or a `_NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It is used to down weight or boost examples during training. It\n      will be multiplied by the loss of the example.\n    thresholds: Iterable of floats in the range `(0, 1)`. For binary\n      classification metrics such as precision and recall, an eval metric is\n      generated for each threshold value. This threshold is applied to the\n      logistic values to determine the binary classification (i.e., above the\n      threshold is `true`, below is `false`.\n    label_vocabulary: A list or tuple of strings representing possible label\n      values. If it is not given, that means labels are already encoded within\n      [0, 1]. If given, labels must be string type and have any value in\n      `label_vocabulary`. Note that errors will be raised if `label_vocabulary`\n      is not provided but labels are strings.\n    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to\n      reduce training loss over batch. Defaults to `SUM`.\n    loss_fn: Optional loss function.\n    name: name of the head. If provided, summary and metrics keys will be\n      suffixed by `\"/\" + name`. 
Also used as `name_scope` when creating ops.\n\n  Returns:\n    An instance of `_Head` for binary classification.\n\n  Raises:\n    ValueError: If `thresholds` contains a value outside of `(0, 1)`.\n    ValueError: If `loss_reduction` is invalid.\n    TypeError: if `label_vocabulary` has invalid type.\n  \"\"\"\n  thresholds = tuple(thresholds) if thresholds else tuple()\n  if label_vocabulary is not None and not isinstance(label_vocabulary,\n                                                     (list, tuple)):\n    raise TypeError(\n        'label_vocabulary should be a list or tuple. Given type: {}'.format(\n            type(label_vocabulary)))\n\n  for threshold in thresholds:\n    if (threshold <= 0.0) or (threshold >= 1.0):\n      raise ValueError('thresholds not in (0, 1): {}.'.format((thresholds,)))\n  if (loss_reduction not in tf.compat.v1.losses.Reduction.all() or\n      loss_reduction == tf.compat.v1.losses.Reduction.NONE):\n    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))\n  if loss_fn:\n    _validate_loss_fn_args(loss_fn)\n  return _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(\n      weight_column=weight_column,\n      thresholds=thresholds,\n      label_vocabulary=label_vocabulary,\n      loss_reduction=loss_reduction,\n      loss_fn=loss_fn,\n      name=name)\n\n\nclass _BinaryLogisticHeadWithSigmoidCrossEntropyLoss(_Head):\n  \"\"\"See `_binary_logistic_head_with_sigmoid_cross_entropy_loss`.\"\"\"\n\n  def __init__(self,\n               weight_column=None,\n               thresholds=None,\n               label_vocabulary=None,\n               loss_reduction=tf.compat.v1.losses.Reduction.SUM,\n               loss_fn=None,\n               name=None):\n    self._weight_column = weight_column\n    self._thresholds = tuple(thresholds) if thresholds else tuple()\n    self._label_vocabulary = label_vocabulary\n    self._loss_reduction = loss_reduction\n    self._loss_fn = loss_fn\n    self._name = name\n\n  @property\n  def 
name(self):\n    return self._name\n\n  @property\n  def logits_dimension(self):\n    return 1\n\n  def _eval_metric_ops(self, labels, logits, logistic, class_ids, weights,\n                       unreduced_loss, regularization_loss):\n    with ops.name_scope(None, 'metrics',\n                        (labels, logits, logistic, class_ids, weights,\n                         unreduced_loss, regularization_loss)):\n      keys = metric_keys.MetricKeys\n      labels_mean = _indicator_labels_mean(\n          labels=labels, weights=weights, name=keys.LABEL_MEAN)\n      metric_ops = {\n          # Estimator already adds a metric for loss.\n          _summary_key(self._name, keys.LOSS_MEAN):\n              tf.compat.v1.metrics.mean(\n                  values=unreduced_loss, weights=weights, name=keys.LOSS_MEAN),\n          _summary_key(self._name, keys.ACCURACY):\n              tf.compat.v1.metrics.accuracy(\n                  labels=labels,\n                  predictions=class_ids,\n                  weights=weights,\n                  name=keys.ACCURACY),\n          _summary_key(self._name, keys.PRECISION):\n              tf.compat.v1.metrics.precision(\n                  labels=labels,\n                  predictions=class_ids,\n                  weights=weights,\n                  name=keys.PRECISION),\n          _summary_key(self._name, keys.RECALL):\n              tf.compat.v1.metrics.recall(\n                  labels=labels,\n                  predictions=class_ids,\n                  weights=weights,\n                  name=keys.RECALL),\n          _summary_key(self._name, keys.PREDICTION_MEAN):\n              _predictions_mean(\n                  predictions=logistic,\n                  weights=weights,\n                  name=keys.PREDICTION_MEAN),\n          _summary_key(self._name, keys.LABEL_MEAN):\n              labels_mean,\n          _summary_key(self._name, keys.ACCURACY_BASELINE):\n              _accuracy_baseline(labels_mean),\n          
_summary_key(self._name, keys.AUC):\n              _auc(\n                  labels=labels,\n                  predictions=logistic,\n                  weights=weights,\n                  name=keys.AUC),\n          _summary_key(self._name, keys.AUC_PR):\n              _auc(\n                  labels=labels,\n                  predictions=logistic,\n                  weights=weights,\n                  curve='PR',\n                  name=keys.AUC_PR)\n      }\n      if regularization_loss is not None:\n        metric_ops[_summary_key(self._name, keys.LOSS_REGULARIZATION)] = (\n            tf.compat.v1.metrics.mean(\n                values=regularization_loss, name=keys.LOSS_REGULARIZATION))\n      for threshold in self._thresholds:\n        accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold\n        metric_ops[_summary_key(self._name,\n                                accuracy_key)] = _accuracy_at_threshold(\n                                    labels=labels,\n                                    predictions=logistic,\n                                    weights=weights,\n                                    threshold=threshold,\n                                    name=accuracy_key)\n        # Precision for positive examples.\n        precision_key = keys.PRECISION_AT_THRESHOLD % threshold\n        metric_ops[_summary_key(self._name,\n                                precision_key)] = _precision_at_threshold(\n                                    labels=labels,\n                                    predictions=logistic,\n                                    weights=weights,\n                                    threshold=threshold,\n                                    name=precision_key)\n        # Recall for positive examples.\n        recall_key = keys.RECALL_AT_THRESHOLD % threshold\n        metric_ops[_summary_key(self._name, recall_key)] = _recall_at_threshold(\n            labels=labels,\n            predictions=logistic,\n            weights=weights,\n            
threshold=threshold,\n            name=recall_key)\n      return metric_ops\n\n  def create_loss(self, features, mode, logits, labels):\n    \"\"\"See `Head`.\"\"\"\n    del mode  # Unused for this head.\n    logits = ops.convert_to_tensor(logits)\n    labels = _check_dense_labels_match_logits_and_reshape(\n        labels=labels, logits=logits, expected_labels_dimension=1)\n    if self._label_vocabulary is not None:\n      labels = lookup_ops.index_table_from_tensor(\n          vocabulary_list=tuple(self._label_vocabulary),\n          name='class_id_lookup').lookup(labels)\n    labels = tf.cast(labels, dtype=tf.dtypes.float32)\n    labels = _assert_range(labels, n_classes=2)\n    if self._loss_fn:\n      unweighted_loss = _call_loss_fn(\n          loss_fn=self._loss_fn,\n          labels=labels,\n          logits=logits,\n          features=features,\n          expected_loss_dim=1)\n    else:\n      unweighted_loss = tf.compat.v1.nn.sigmoid_cross_entropy_with_logits(\n          labels=labels, logits=logits)\n    weights = _get_weights_and_check_match_logits(\n        features=features, weight_column=self._weight_column, logits=logits)\n    training_loss = tf.compat.v1.losses.compute_weighted_loss(\n        unweighted_loss, weights=weights, reduction=self._loss_reduction)\n    return LossSpec(\n        training_loss=training_loss,\n        unreduced_loss=unweighted_loss,\n        weights=weights,\n        processed_labels=labels)\n\n  def _create_tpu_estimator_spec(self,\n                                 features,\n                                 mode,\n                                 logits,\n                                 labels=None,\n                                 optimizer=None,\n                                 train_op_fn=None,\n                                 regularization_losses=None):\n    \"\"\"Returns an `EstimatorSpec`.\n\n    Args:\n      features: Input `dict` of `Tensor` or `SparseTensor` objects.\n      mode: Estimator's `ModeKeys`.\n      
logits: logits `Tensor` with shape `[D0, D1, ... DN, 1]`. For many\n        applications, the shape is `[batch_size, 1]`.\n      labels: Labels integer or string `Tensor` with shape matching `logits`,\n        namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. `labels` is required\n        argument when `mode` equals `TRAIN` or `EVAL`.\n      optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.\n        Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which\n        updates variables and increments `global_step`.\n      train_op_fn: Function that takes a scalar loss `Tensor` and returns\n        `train_op`. Used if `optimizer` is `None`.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses. These losses are\n        usually expressed as a batch average, so for best results users need to\n        set `loss_reduction=SUM_OVER_BATCH_SIZE` when creating the head to avoid\n        scaling errors.\n\n    Returns:\n      `EstimatorSpec`.\n    Raises:\n      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN\n        mode, or if both are set.\n    \"\"\"\n    # Predict.\n    with ops.name_scope(self._name, 'head'):\n      with ops.name_scope(None, 'predictions', (logits,)):\n        pred_keys = prediction_keys.PredictionKeys\n        logits = _check_logits_final_dim(logits, self.logits_dimension)\n        logistic = tf.math.sigmoid(logits, name=pred_keys.LOGISTIC)\n        two_class_logits = tf.concat((tf.compat.v1.zeros_like(logits), logits),\n                                     axis=-1,\n                                     name='two_class_logits')\n        probabilities = tf.compat.v1.nn.softmax(\n            two_class_logits, name=pred_keys.PROBABILITIES)\n        class_ids = tf.compat.v1.math.argmax(\n            two_class_logits, axis=-1, name=pred_keys.CLASS_IDS)\n        class_ids = tf.compat.v1.expand_dims(class_ids, axis=-1)\n   
     all_class_ids = _all_class_ids(logits, n_classes=2)\n        all_classes = _all_classes(\n            logits, n_classes=2, label_vocabulary=self._label_vocabulary)\n\n        if self._label_vocabulary:\n          table = lookup_ops.index_to_string_table_from_tensor(\n              vocabulary_list=self._label_vocabulary,\n              name='class_string_lookup')\n          classes = table.lookup(class_ids)\n        else:\n          classes = string_ops.as_string(class_ids, name='str_classes')\n        predictions = {\n            pred_keys.LOGITS: logits,\n            pred_keys.LOGISTIC: logistic,\n            pred_keys.PROBABILITIES: probabilities,\n            pred_keys.CLASS_IDS: class_ids,\n            pred_keys.CLASSES: classes,\n            pred_keys.ALL_CLASS_IDS: all_class_ids,\n            pred_keys.ALL_CLASSES: all_classes,\n        }\n      if mode == ModeKeys.PREDICT:\n        classifier_output = _classification_output(\n            scores=probabilities,\n            n_classes=2,\n            label_vocabulary=self._label_vocabulary)\n        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n            mode=ModeKeys.PREDICT,\n            predictions=predictions,\n            export_outputs={\n                _DEFAULT_SERVING_KEY: classifier_output,\n                _CLASSIFY_SERVING_KEY: classifier_output,\n                _REGRESS_SERVING_KEY: export_output.RegressionOutput(\n                    value=logistic),\n                _PREDICT_SERVING_KEY: export_output.PredictOutput(predictions)\n            })\n\n      (training_loss, unreduced_loss, weights, processed_labels) = (\n          self.create_loss(\n              features=features, mode=mode, logits=logits, labels=labels))\n      if regularization_losses:\n        regularization_loss = tf.math.add_n(regularization_losses)\n        regularized_training_loss = tf.math.add_n(\n            [training_loss, regularization_loss])\n      else:\n        regularization_loss = 
None\n        regularized_training_loss = training_loss\n\n      # Eval.\n      if mode == ModeKeys.EVAL:\n        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n            mode=ModeKeys.EVAL,\n            predictions=predictions,\n            loss=regularized_training_loss,\n            eval_metrics=_create_eval_metrics_tuple(\n                self._eval_metric_ops, {\n                    'labels': processed_labels,\n                    'logits': logits,\n                    'logistic': logistic,\n                    'class_ids': class_ids,\n                    'weights': weights,\n                    'unreduced_loss': unreduced_loss,\n                    'regularization_loss': regularization_loss\n                }))\n\n      # Train.\n      if optimizer is not None:\n        if train_op_fn is not None:\n          raise ValueError('train_op_fn and optimizer cannot both be set.')\n        train_op = optimizer.minimize(\n            regularized_training_loss,\n            global_step=tf.compat.v1.train.get_global_step())\n      elif train_op_fn is not None:\n        train_op = train_op_fn(regularized_training_loss)\n      else:\n        raise ValueError('train_op_fn and optimizer cannot both be None.')\n      train_op = _append_update_ops(train_op)\n      # Only summarize mean_loss for SUM reduction to preserve backwards\n      # compatibility. 
Otherwise skip it to avoid unnecessary computation.\n      if self._loss_reduction == tf.compat.v1.losses.Reduction.SUM:\n        example_weight_sum = tf.math.reduce_sum(\n            weights * tf.compat.v1.ones_like(unreduced_loss))\n        mean_loss = training_loss / example_weight_sum\n      else:\n        mean_loss = None\n    with ops.name_scope(''):\n      keys = metric_keys.MetricKeys\n      tf.compat.v1.summary.scalar(\n          _summary_key(self._name, keys.LOSS), regularized_training_loss)\n      if mean_loss is not None:\n        tf.compat.v1.summary.scalar(\n            _summary_key(self._name, keys.LOSS_MEAN), mean_loss)\n      if regularization_loss is not None:\n        tf.compat.v1.summary.scalar(\n            _summary_key(self._name, keys.LOSS_REGULARIZATION),\n            regularization_loss)\n    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n        mode=ModeKeys.TRAIN,\n        predictions=predictions,\n        loss=regularized_training_loss,\n        train_op=train_op)\n\n\ndef _regression_head(weight_column=None,\n                     label_dimension=1,\n                     loss_reduction=tf.compat.v1.losses.Reduction.SUM,\n                     loss_fn=None,\n                     inverse_link_fn=None,\n                     name=None):\n  \"\"\"Creates a `_Head` for regression using the `mean_squared_error` loss.\n\n  The loss is the weighted sum over all input dimensions. Namely, if the input\n  labels have shape `[batch_size, label_dimension]`, the loss is the weighted\n  sum over both `batch_size` and `label_dimension`.\n\n  The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`.\n  In many applications, the shape is `[batch_size, label_dimension]`.\n\n  The `labels` shape must match `logits`, namely\n  `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape\n  `[D0, D1, ... DN]` is also supported.\n\n  If `weight_column` is specified, weights must be of shape\n  `[D0, D1, ... 
DN]`, `[D0, D1, ... DN, 1]` or\n  `[D0, D1, ... DN, label_dimension]`.\n\n  Supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or\n  `(labels, logits, features)` as arguments and returns unreduced loss with\n  shape `[D0, D1, ... DN, label_dimension]`.\n\n  Also supports custom `inverse_link_fn`, also known as 'mean function'.\n  `inverse_link_fn` takes `logits` as argument and returns predicted values.\n  This function is the inverse of the link function defined in\n  https://en.wikipedia.org/wiki/Generalized_linear_model#Link_function\n  Namely, for poisson regression, set `inverse_link_fn=tf.exp`.\n\n  Args:\n    weight_column: A string or a `_NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It is used to down weight or boost examples during training. It\n      will be multiplied by the loss of the example.\n    label_dimension: Number of regression labels per example. This is the size\n      of the last dimension of the labels `Tensor` (typically, this has shape\n      `[batch_size, label_dimension]`).\n    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to\n      reduce training loss over batch. Defaults to `SUM`.\n    loss_fn: Optional loss function. Defaults to `mean_squared_error`.\n    inverse_link_fn: Optional inverse link function, also known as 'mean\n      function'. Defaults to identity.\n    name: name of the head. If provided, summary and metrics keys will be\n      suffixed by `\"/\" + name`. 
Also used as `name_scope` when creating ops.\n\n  Returns:\n    An instance of `_Head` for linear regression.\n\n  Raises:\n    ValueError: If `label_dimension` or `loss_reduction` is invalid.\n  \"\"\"\n  if (loss_reduction not in tf.compat.v1.losses.Reduction.all() or\n      loss_reduction == tf.compat.v1.losses.Reduction.NONE):\n    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))\n  if loss_fn:\n    _validate_loss_fn_args(loss_fn)\n  return _RegressionHeadWithMeanSquaredErrorLoss(\n      weight_column=weight_column,\n      label_dimension=label_dimension,\n      loss_reduction=loss_reduction,\n      loss_fn=loss_fn,\n      inverse_link_fn=inverse_link_fn,\n      name=name)\n\n\nclass _RegressionHeadWithMeanSquaredErrorLoss(_Head):\n  \"\"\"`Head` for regression using the mean squared loss.\"\"\"\n\n  def __init__(self,\n               label_dimension,\n               weight_column=None,\n               loss_reduction=tf.compat.v1.losses.Reduction.SUM,\n               loss_fn=None,\n               inverse_link_fn=None,\n               name=None):\n    \"\"\"`Head` for regression.\"\"\"\n    if label_dimension < 1:\n      raise ValueError('Invalid label_dimension %s.' 
% label_dimension)\n    self._logits_dimension = label_dimension\n    self._weight_column = weight_column\n    self._loss_reduction = loss_reduction\n    self._loss_fn = loss_fn\n    self._inverse_link_fn = inverse_link_fn\n    self._name = name\n\n  @property\n  def name(self):\n    return self._name\n\n  @property\n  def logits_dimension(self):\n    return self._logits_dimension\n\n  def create_loss(self, features, mode, logits, labels):\n    \"\"\"See `Head`.\"\"\"\n    del mode  # Unused for this head.\n    logits = ops.convert_to_tensor(logits)\n    labels = _check_dense_labels_match_logits_and_reshape(\n        labels=labels,\n        logits=logits,\n        expected_labels_dimension=self._logits_dimension)\n    labels = tf.cast(labels, dtype=tf.dtypes.float32)\n    if self._loss_fn:\n      unweighted_loss = _call_loss_fn(\n          loss_fn=self._loss_fn,\n          labels=labels,\n          logits=logits,\n          features=features,\n          expected_loss_dim=self._logits_dimension)\n    else:\n      unweighted_loss = tf.compat.v1.losses.mean_squared_error(\n          labels=labels,\n          predictions=logits,\n          reduction=tf.compat.v1.losses.Reduction.NONE)\n    weights = _get_weights_and_check_match_logits(\n        features=features,\n        weight_column=self._weight_column,\n        logits=logits,\n        allow_per_logit_weights=True)\n    training_loss = tf.compat.v1.losses.compute_weighted_loss(\n        unweighted_loss, weights=weights, reduction=self._loss_reduction)\n    return LossSpec(\n        training_loss=training_loss,\n        unreduced_loss=unweighted_loss,\n        weights=weights,\n        processed_labels=labels)\n\n  def _eval_metric_ops(self, predicted_value, labels, weights, unreduced_loss,\n                       regularization_loss):\n    \"\"\"Returns the Eval metric ops.\"\"\"\n    keys = metric_keys.MetricKeys\n    # Estimator already adds a metric for loss.\n    eval_metric_ops = {\n        
_summary_key(self._name, keys.LOSS_MEAN):\n            tf.compat.v1.metrics.mean(values=unreduced_loss, weights=weights),\n        _summary_key(self._name, keys.PREDICTION_MEAN):\n            _predictions_mean(\n                predictions=predicted_value,\n                weights=weights,\n                name=keys.PREDICTION_MEAN),\n        _summary_key(self._name, keys.LABEL_MEAN):\n            tf.compat.v1.metrics.mean(values=labels, weights=weights)\n    }\n    if regularization_loss is not None:\n      regularization_loss_key = _summary_key(self._name,\n                                             keys.LOSS_REGULARIZATION)\n      eval_metric_ops[regularization_loss_key] = tf.compat.v1.metrics.mean(\n          values=regularization_loss, name=keys.LOSS_REGULARIZATION)\n    return eval_metric_ops\n\n  def _create_tpu_estimator_spec(self,\n                                 features,\n                                 mode,\n                                 logits,\n                                 labels=None,\n                                 optimizer=None,\n                                 train_op_fn=None,\n                                 regularization_losses=None):\n    \"\"\"Returns an `EstimatorSpec`.\n\n    Args:\n      features: Input `dict` of `Tensor` or `SparseTensor` objects.\n      mode: Estimator's `ModeKeys`.\n      logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.\n        For many applications, the shape is `[batch_size, logits_dimension]`.\n      labels: Labels `Tensor` with shape matching `logits`, namely `[D0, D1, ...\n        DN, logits_dimension]`. When `logits_dimension=1`, shape `[D0, D1, ...\n        DN]` is also supported. 
`labels` is required argument when `mode` equals\n        `TRAIN` or `EVAL`.\n      optimizer: `Optimizer` instance to optimize the loss in TRAIN mode.\n        Namely, sets `train_op = optimizer.minimize(loss, global_step)`, which\n        updates variables and increments `global_step`.\n      train_op_fn: Function that takes a scalar loss `Tensor` and returns\n        `train_op`. Used if `optimizer` is `None`.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses. These losses are\n        usually expressed as a batch average, so for best results users need to\n        set `loss_reduction=SUM_OVER_BATCH_SIZE` when creating the head to avoid\n        scaling errors.\n\n    Returns:\n      A `model_fn._TPUEstimatorSpec` instance.\n    Raises:\n      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN\n        mode, or if both are set.\n    \"\"\"\n    # Predict.\n    with ops.name_scope(self._name, 'head'):\n      logits = _check_logits_final_dim(logits, self._logits_dimension)\n      if self._inverse_link_fn:\n        predicted_value = self._inverse_link_fn(logits)\n        predictions = {\n            prediction_keys.PredictionKeys.PREDICTIONS: predicted_value,\n            prediction_keys.PredictionKeys.LOGITS: logits,\n        }\n      else:\n        predicted_value = logits\n        predictions = {\n            prediction_keys.PredictionKeys.PREDICTIONS: predicted_value\n        }\n      if mode == ModeKeys.PREDICT:\n        regression_output = export_output.RegressionOutput(\n            value=predicted_value)\n        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n            mode=ModeKeys.PREDICT,\n            predictions=predictions,\n            export_outputs={\n                _DEFAULT_SERVING_KEY: regression_output,\n                _REGRESS_SERVING_KEY: regression_output,\n                _PREDICT_SERVING_KEY: 
export_output.PredictOutput(predictions)\n            })\n\n      training_loss, unreduced_loss, weights, _ = self.create_loss(\n          features=features, mode=mode, logits=logits, labels=labels)\n      if regularization_losses:\n        regularization_loss = tf.math.add_n(regularization_losses)\n        regularized_training_loss = tf.math.add_n(\n            [training_loss, regularization_loss])\n      else:\n        regularization_loss = None\n        regularized_training_loss = training_loss\n\n      # Eval.\n      if mode == ModeKeys.EVAL:\n        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n            mode=ModeKeys.EVAL,\n            predictions=predictions,\n            loss=regularized_training_loss,\n            eval_metrics=_create_eval_metrics_tuple(\n                self._eval_metric_ops, {\n                    'predicted_value': predicted_value,\n                    'labels': labels,\n                    'weights': weights,\n                    'unreduced_loss': unreduced_loss,\n                    'regularization_loss': regularization_loss,\n                }))\n\n      # Train.\n      if optimizer is not None:\n        if train_op_fn is not None:\n          raise ValueError('train_op_fn and optimizer cannot both be set.')\n        train_op = optimizer.minimize(\n            regularized_training_loss,\n            global_step=tf.compat.v1.train.get_global_step())\n      elif train_op_fn is not None:\n        train_op = train_op_fn(regularized_training_loss)\n      else:\n        raise ValueError('train_op_fn and optimizer cannot both be None.')\n      train_op = _append_update_ops(train_op)\n      # Only summarize mean_loss for SUM reduction to preserve backwards\n      # compatibility. 
Otherwise skip it to avoid unnecessary computation.\n      if self._loss_reduction == tf.compat.v1.losses.Reduction.SUM:\n        example_weight_sum = tf.math.reduce_sum(\n            weights * tf.compat.v1.ones_like(unreduced_loss))\n        mean_loss = training_loss / example_weight_sum\n      else:\n        mean_loss = None\n    with ops.name_scope(''):\n      keys = metric_keys.MetricKeys\n      tf.compat.v1.summary.scalar(\n          _summary_key(self._name, keys.LOSS), regularized_training_loss)\n      if mean_loss is not None:\n        tf.compat.v1.summary.scalar(\n            _summary_key(self._name, keys.LOSS_MEAN), mean_loss)\n      if regularization_loss is not None:\n        tf.compat.v1.summary.scalar(\n            _summary_key(self._name, keys.LOSS_REGULARIZATION),\n            regularization_loss)\n    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n        mode=ModeKeys.TRAIN,\n        predictions=predictions,\n        loss=regularized_training_loss,\n        train_op=train_op)\n\n\ndef _append_update_ops(train_op):\n  \"\"\"Returns `train_op` appending `UPDATE_OPS` collection if present.\"\"\"\n  update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)\n  if update_ops:\n    return tf.group(train_op, *update_ops)\n  return train_op\n\n\ndef _assert_range(labels, n_classes, message=None):\n  with ops.name_scope(None, 'assert_range', (labels,)):\n    assert_less = tf.compat.v1.debugging.assert_less_equal(\n        labels,\n        ops.convert_to_tensor(n_classes - 1, dtype=labels.dtype),\n        message=message or 'Labels must <= n_classes - 1')\n    assert_greater = tf.compat.v1.debugging.assert_non_negative(\n        labels, message=message or 'Labels must >= 0')\n    with tf.control_dependencies((assert_less, assert_greater)):\n      return tf.identity(labels)\n\n\ndef _binary_logistic_or_multi_class_head(n_classes, weight_column,\n                                         label_vocabulary, 
loss_reduction):\n  \"\"\"Creates either binary or multi-class head.\n\n  Args:\n    n_classes: Number of label classes.\n    weight_column: A string or a `_NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It is used to down weight or boost examples during training. It\n      will be multiplied by the loss of the example. If it is a string, it is\n      used as a key to fetch weight tensor from the `features`. If it is a\n      `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then\n      weight_column.normalizer_fn is applied on it to get weight tensor.\n    label_vocabulary: A list of strings represents possible label values. If\n      given, labels must be string type and have any value in\n      `label_vocabulary`. If it is not given, that means labels are already\n      encoded as integer or float within [0, 1] for `n_classes=2` and encoded as\n      integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there\n      will be errors if vocabulary is not provided and labels are string.\n    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to\n      reduce training loss over batch. Defaults to `SUM`.\n\n  Returns:\n    `head._Head` instance.\n  \"\"\"\n  if n_classes == 2:\n    head = _binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        weight_column=weight_column,\n        label_vocabulary=label_vocabulary,\n        loss_reduction=loss_reduction)\n  else:\n    head = _multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes,\n        weight_column=weight_column,\n        label_vocabulary=label_vocabulary,\n        loss_reduction=loss_reduction)\n  return head\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/head_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for head.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator import model_fn\nfrom tensorflow_estimator.python.estimator.canned import head as head_lib\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.canned.v1 import dnn_testing_utils_v1\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n_DEFAULT_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n\n\ndef _initialize_variables(test_case, scaffold):\n  scaffold.finalize()\n  test_case.assertIsNone(scaffold.init_feed_dict)\n  test_case.assertIsNone(scaffold.init_fn)\n  scaffold.init_op.run()\n  scaffold.ready_for_local_init_op.eval()\n  scaffold.local_init_op.run()\n  scaffold.ready_op.eval()\n  test_case.assertIsNotNone(scaffold.saver)\n\n\ndef _assert_simple_summaries(test_case,\n                             expected_summaries,\n        
                     summary_str,\n                             tol=1e-6):\n  \"\"\"Assert summary the specified simple values.\n\n  Args:\n    test_case: test case.\n    expected_summaries: Dict of expected tags and simple values.\n    summary_str: Serialized `summary_pb2.Summary`.\n    tol: Tolerance for relative and absolute.\n  \"\"\"\n  summary = tf.compat.v1.summary.Summary()\n  summary.ParseFromString(summary_str)\n  test_case.assertAllClose(\n      expected_summaries, {v.tag: v.simple_value for v in summary.value},\n      rtol=tol,\n      atol=tol)\n\n\ndef _assert_no_hooks(test_case, spec):\n  test_case.assertAllEqual([], spec.training_chief_hooks)\n  test_case.assertAllEqual([], spec.training_hooks)\n\n\ndef _sigmoid(logits):\n  return 1 / (1 + np.exp(-logits))\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass CreateEstimatorSpecTest(tf.test.TestCase):\n\n  class _HeadWithTPUSupport(head_lib._Head):\n    \"\"\"Head that overrides _create_tpu_estimator_spec.\"\"\"\n\n    def name(self):\n      return 'HeadWithTPUSupport'\n\n    def logits_dimension(self):\n      return None\n\n    def create_loss(self, features, mode, logits, labels):\n      return None\n\n    def _create_tpu_estimator_spec(self,\n                                   features,\n                                   mode,\n                                   logits,\n                                   labels=None,\n                                   optimizer=None,\n                                   train_op_fn=None,\n                                   regularization_losses=None):\n      return model_fn._TPUEstimatorSpec(\n          mode=ModeKeys.EVAL, loss=tf.constant(0.0, dtype=tf.dtypes.float32))\n\n  class _HeadWithOutTPUSupport(head_lib._Head):\n    \"\"\"Head that overrides create_estimator_spec.\"\"\"\n\n    def name(self):\n      return 'HeadWithOutTPUSupport'\n\n    def logits_dimension(self):\n      return None\n\n    def create_loss(self, features, mode, logits, labels):\n      
return None\n\n    def create_estimator_spec(self,\n                              features,\n                              mode,\n                              logits,\n                              labels=None,\n                              optimizer=None,\n                              train_op_fn=None,\n                              regularization_losses=None):\n      return model_fn.EstimatorSpec(\n          mode=ModeKeys.EVAL, loss=tf.constant(0.0, dtype=tf.dtypes.float32))\n\n  class _InvalidHead(head_lib._Head):\n    \"\"\"Head that overrides neither estimator_spec functions.\"\"\"\n\n    def name(self):\n      return 'InvalidHead'\n\n    def logits_dimension(self):\n      return None\n\n    def create_loss(self, features, mode, logits, labels):\n      return None\n\n  def test_head_override_tpu_estimator_spec(self):\n    \"\"\"Test for `_Head` that overrides _create_tpu_estimator_spec.\"\"\"\n    head = self._HeadWithTPUSupport()\n\n    tpu_spec = head._create_tpu_estimator_spec(\n        features=None, mode=None, logits=None)\n    self.assertTrue(isinstance(tpu_spec, model_fn._TPUEstimatorSpec))\n    est_spec = head.create_estimator_spec(features=None, mode=None, logits=None)\n    self.assertTrue(isinstance(est_spec, model_fn.EstimatorSpec))\n\n  def test_head_override_estimator_spec(self):\n    \"\"\"Test for `_Head` that overrides create_estimator_spec.\"\"\"\n    head = self._HeadWithOutTPUSupport()\n\n    with self.assertRaisesRegexp(\n        NotImplementedError,\n        'TPUEstimatorSpec not available for this model head.'):\n      _ = head._create_tpu_estimator_spec(features=None, mode=None, logits=None)\n    est_spec = head.create_estimator_spec(features=None, mode=None, logits=None)\n    self.assertTrue(isinstance(est_spec, model_fn.EstimatorSpec))\n\n  def test_invalid_head_class(self):\n    head = self._InvalidHead()\n\n    with self.assertRaisesRegexp(\n        NotImplementedError,\n        'TPUEstimatorSpec not available for this model 
head.'):\n      _ = head._create_tpu_estimator_spec(features=None, mode=None, logits=None)\n    with self.assertRaisesRegexp(\n        NotImplementedError,\n        r'Subclasses of _Head must implement `create_estimator_spec\\(\\)` or '\n        r'_create_tpu_estimator_spec\\(\\).'):\n      _ = head.create_estimator_spec(features=None, mode=None, logits=None)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass MultiClassHeadWithSoftmaxCrossEntropyLoss(tf.test.TestCase):\n\n  def setUp(self):\n    tf.compat.v1.reset_default_graph()\n\n  def test_n_classes_is_none(self):\n    with self.assertRaisesRegexp(ValueError, 'n_classes cannot be None'):\n      head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes=None)\n\n  def test_n_classes_is_2(self):\n    with self.assertRaisesRegexp(ValueError, 'n_classes must be > 2'):\n      head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes=2)\n\n  def test_invalid_loss_reduction(self):\n    with self.assertRaisesRegexp(\n        ValueError, r'Invalid loss_reduction: invalid_loss_reduction'):\n      head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n          n_classes=3, loss_reduction='invalid_loss_reduction')\n    with self.assertRaisesRegexp(ValueError, r'Invalid loss_reduction: none'):\n      head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n          n_classes=3, loss_reduction=tf.compat.v1.losses.Reduction.NONE)\n\n  def test_loss_fn_arg_labels_missing(self):\n\n    def _loss_fn(logits):\n      del logits  # Unused\n\n    with self.assertRaisesRegexp(\n        ValueError, r'loss_fn must contain argument: labels\\. 
'\n        r'Given arguments: \\(\\'logits\\',\\)'):\n      head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n          n_classes=3, loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_logits_missing(self):\n\n    def _loss_fn(labels):\n      del labels  # unused\n\n    with self.assertRaisesRegexp(\n        ValueError, r'loss_fn must contain argument: logits\\. '\n        r'Given arguments: \\(\\'labels\\',\\)'):\n      head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n          n_classes=3, loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_features_ok(self):\n\n    def _loss_fn(labels, logits, features):\n      del labels, logits, features  # Unused\n\n    head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes=3, loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_invalid(self):\n\n    def _loss_fn(labels, logits, name=None):\n      del labels, logits, name  # Unused\n\n    with self.assertRaisesRegexp(ValueError,\n                                 r'loss_fn has unexpected args: \\[\\'name\\'\\]'):\n      head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n          n_classes=3, loss_fn=_loss_fn)\n\n  def test_invalid_logits_shape(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    # Logits should be shape (batch_size, 3).\n    logits_2x2 = np.array(((45., 44.), (41., 42.),))\n\n    # Static shape.\n    with self.assertRaisesRegexp(ValueError, 'logits shape'):\n      head.create_estimator_spec(\n          features={'x': np.array(((30.,), (42.,),))},\n          mode=ModeKeys.PREDICT,\n          logits=logits_2x2)\n\n    # Dynamic shape.\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((30.,), (42.,),))},\n        mode=ModeKeys.PREDICT,\n        logits=logits_placeholder)\n    with self.cached_session():\n      with 
self.assertRaisesRegexp(tf.errors.OpError, 'logits shape'):\n        spec.predictions[prediction_keys.PredictionKeys.PROBABILITIES].eval(\n            {logits_placeholder: logits_2x2})\n\n  def test_invalid_labels_shape(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    # Logits should be shape (batch_size, 3).\n    # Labels should be shape (batch_size, 1).\n    labels_2x2 = np.array(((45, 44), (41, 42),), dtype=int)\n    logits_2x3 = np.array(((1., 2., 3.), (1., 2., 3.),))\n    features = {'x': np.array(((42.,),))}\n\n    # Static shape.\n    with self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):\n      head.create_loss(\n          features=features,\n          mode=ModeKeys.EVAL,\n          logits=logits_2x3,\n          labels=labels_2x2)\n\n    # Dynamic shape.\n    labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.int64)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    training_loss = head.create_loss(\n        features=features,\n        mode=ModeKeys.EVAL,\n        logits=logits_placeholder,\n        labels=labels_placeholder)[0]\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[expected_labels_shape: \\] \\[2 1\\] \\[labels_shape: \\] \\[2 2\\]'):\n        training_loss.eval({\n            logits_placeholder: logits_2x3,\n            labels_placeholder: labels_2x2\n        })\n\n  def test_invalid_labels_type(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    # Logits should be shape (batch_size, 3).\n    # Labels should be shape (batch_size, 1).\n    labels_2x1 = np.array(((1.,), (1.,),))\n    logits_2x3 = np.array(((1., 2., 3.), (1., 2., 3.),))\n    features = {'x': 
np.array(((42.,),))}\n\n    # Static shape.\n    with self.assertRaisesRegexp(ValueError, 'Labels dtype'):\n      head.create_loss(\n          features=features,\n          mode=ModeKeys.EVAL,\n          logits=logits_2x3,\n          labels=labels_2x1)\n\n    # Dynamic shape.\n    labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    with self.assertRaisesRegexp(ValueError, 'Labels dtype'):\n      head.create_loss(\n          features=features,\n          mode=ModeKeys.EVAL,\n          logits=logits_placeholder,\n          labels=labels_placeholder)\n\n  def test_invalid_labels_values(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    labels_2x1_with_large_id = np.array(((45,), (1,),), dtype=int)\n    labels_2x1_with_negative_id = np.array(((-5,), (1,),), dtype=int)\n    logits_2x3 = np.array(((1., 2., 4.), (1., 2., 3.),))\n\n    labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.int64)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    training_loss = head.create_loss(\n        features={'x': np.array(((42.,),))},\n        mode=ModeKeys.EVAL,\n        logits=logits_placeholder,\n        labels=labels_placeholder)[0]\n    with self.cached_session():\n      with self.assertRaisesOpError('Labels must <= n_classes - 1'):\n        training_loss.eval({\n            labels_placeholder: labels_2x1_with_large_id,\n            logits_placeholder: logits_2x3\n        })\n\n    with self.cached_session():\n      with self.assertRaisesOpError('Labels must >= 0'):\n        training_loss.eval({\n            labels_placeholder: labels_2x1_with_negative_id,\n            logits_placeholder: logits_2x3\n        })\n\n  def test_invalid_labels_sparse_tensor(self):\n    n_classes = 3\n    head = 
head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    labels_2x1 = tf.sparse.SparseTensor(\n        values=['english', 'italian'],\n        indices=[[0, 0], [1, 0]],\n        dense_shape=[2, 1])\n    logits_2x3 = np.array(((1., 2., 4.), (1., 2., 3.),))\n\n    with self.assertRaisesRegexp(ValueError,\n                                 'SparseTensor labels are not supported.'):\n      head.create_loss(\n          features={'x': np.array(((42.,),))},\n          mode=ModeKeys.EVAL,\n          logits=logits_2x3,\n          labels=labels_2x1)\n\n  def test_incompatible_labels_shape(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    # Logits should be shape (batch_size, 3).\n    # Labels should be shape (batch_size, 1).\n    # Here batch sizes are different.\n    values_3x1 = np.array(((1,), (1,), (1,),))\n    values_2x3 = np.array(((1., 2., 3.), (1., 2., 3.),))\n    features = {'x': values_2x3}\n\n    # Static shape.\n    with self.assertRaisesRegexp(\n        ValueError,\n        r'Shape mismatch: The shape of labels \\(received \\(3,\\)\\) should equal '\n        r'the shape of logits except for the last dimension '\n        r'\\(received \\(2, 3\\)\\)\\.'):\n      head.create_loss(\n          features=features,\n          mode=ModeKeys.EVAL,\n          logits=values_2x3,\n          labels=values_3x1)\n\n    # Dynamic shape.\n    labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.int64)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    training_loss = head.create_loss(\n        features=features,\n        mode=ModeKeys.EVAL,\n        logits=logits_placeholder,\n        labels=labels_placeholder)[0]\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          
r'\\[expected_labels_shape: \\] \\[2 1\\] \\[labels_shape: \\] \\[3 1\\]'):\n        training_loss.eval({\n            labels_placeholder: values_3x1,\n            logits_placeholder: values_2x3\n        })\n\n  def test_name(self):\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes=3, name='foo')\n    self.assertEqual('foo', head.name)\n\n  def test_predict(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    logits = [[1., 0., 0.], [0., 0., 1.]]\n    expected_probabilities = [[0.576117, 0.2119416, 0.2119416],\n                              [0.2119416, 0.2119416, 0.576117]]\n    expected_class_ids = [[0], [2]]\n    expected_all_class_ids = [[0, 1, 2]] * 2\n    expected_classes = [[b'0'], [b'2']]\n    expected_all_classes = [[b'0', b'1', b'2']] * 2\n    expected_export_classes = [[b'0', b'1', b'2']] * 2\n\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits)\n\n    self.assertItemsEqual((_DEFAULT_SERVING_KEY, 'predict', 'classification'),\n                          spec.export_outputs.keys())\n\n    # Assert predictions and export_outputs.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      predictions = sess.run(spec.predictions)\n      self.assertAllClose(logits,\n                          predictions[prediction_keys.PredictionKeys.LOGITS])\n      self.assertAllClose(\n          expected_probabilities,\n          predictions[prediction_keys.PredictionKeys.PROBABILITIES])\n      self.assertAllClose(expected_class_ids,\n                          predictions[prediction_keys.PredictionKeys.CLASS_IDS])\n      self.assertAllEqual(expected_classes,\n                          
predictions[prediction_keys.PredictionKeys.CLASSES])\n      self.assertAllClose(\n          expected_all_class_ids,\n          predictions[prediction_keys.PredictionKeys.ALL_CLASS_IDS])\n      self.assertAllEqual(\n          expected_all_classes,\n          predictions[prediction_keys.PredictionKeys.ALL_CLASSES])\n\n      self.assertAllClose(\n          expected_probabilities,\n          sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores))\n      self.assertAllEqual(\n          expected_export_classes,\n          sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].classes))\n\n  def test_predict_with_tensor_n_classes(self):\n    n_classes = tf.constant(3, dtype=tf.dtypes.int32)\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    logits = [[1., 0., 0.], [0., 0., 1.]]\n    expected_probabilities = [[0.576117, 0.2119416, 0.2119416],\n                              [0.2119416, 0.2119416, 0.576117]]\n    expected_class_ids = [[0], [2]]\n    expected_all_class_ids = [[0, 1, 2]] * 2\n    expected_classes = [[b'0'], [b'2']]\n    expected_all_classes = [[b'0', b'1', b'2']] * 2\n    expected_export_classes = [[b'0', b'1', b'2']] * 2\n\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits)\n\n    self.assertItemsEqual((_DEFAULT_SERVING_KEY, 'predict', 'classification'),\n                          spec.export_outputs.keys())\n\n    # Assert predictions and export_outputs.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      predictions = sess.run(spec.predictions)\n      self.assertAllClose(logits,\n                          predictions[prediction_keys.PredictionKeys.LOGITS])\n      self.assertAllClose(\n          expected_probabilities,\n          
predictions[prediction_keys.PredictionKeys.PROBABILITIES])\n      self.assertAllClose(expected_class_ids,\n                          predictions[prediction_keys.PredictionKeys.CLASS_IDS])\n      self.assertAllEqual(expected_classes,\n                          predictions[prediction_keys.PredictionKeys.CLASSES])\n      self.assertAllClose(\n          expected_all_class_ids,\n          predictions[prediction_keys.PredictionKeys.ALL_CLASS_IDS])\n      self.assertAllEqual(\n          expected_all_classes,\n          predictions[prediction_keys.PredictionKeys.ALL_CLASSES])\n\n      self.assertAllClose(\n          expected_probabilities,\n          sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores))\n      self.assertAllEqual(\n          expected_export_classes,\n          sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].classes))\n\n  def test_predict_with_vocabulary_list(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes, label_vocabulary=['aang', 'iroh', 'zuko'])\n\n    logits = [[1., 0., 0.], [0., 0., 1.]]\n    expected_classes = [[b'aang'], [b'zuko']]\n    expected_export_classes = [[b'aang', b'iroh', b'zuko']] * 2\n\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits)\n\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertAllEqual(\n          expected_classes,\n          sess.run(spec.predictions[prediction_keys.PredictionKeys.CLASSES]))\n      self.assertAllEqual(\n          expected_export_classes,\n          sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].classes))\n\n  def test_weight_should_not_impact_prediction(self):\n    n_classes = 3\n    logits = [[1., 0., 0.], [0., 0., 1.]]\n    expected_probabilities = [[0.576117, 0.2119416, 0.2119416],\n                              [0.2119416, 0.2119416, 0.576117]]\n    head = 
head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes, weight_column='label_weights')\n\n    weights_2x1 = [[1.], [2.]]\n    spec = head.create_estimator_spec(\n        features={\n            'x': np.array(((42,),), dtype=np.int32),\n            'label_weights': weights_2x1,\n        },\n        mode=ModeKeys.PREDICT,\n        logits=logits)\n\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      predictions = sess.run(spec.predictions)\n      self.assertAllClose(logits,\n                          predictions[prediction_keys.PredictionKeys.LOGITS])\n      self.assertAllClose(\n          expected_probabilities,\n          predictions[prediction_keys.PredictionKeys.PROBABILITIES])\n\n  def test_eval_create_loss(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)\n\n    logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)\n    labels = np.array(((1,), (1,)), dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # loss = cross_entropy(labels, logits) = [10, 0].\n    expected_training_loss = 10.\n    # Create loss.\n    training_loss = head.create_loss(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(\n          expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)\n\n  def test_eval_create_loss_loss_fn(self):\n    \"\"\"Tests head.create_loss for eval mode and custom loss_fn.\"\"\"\n    loss = np.array([[1.], [2.]], dtype=np.float32)\n    logits_input = np.array([[-10., 10., 0.], [-15., 10., 0]], dtype=np.float32)\n    labels_input = np.array([[1], [2]], dtype=np.int64)\n\n    def _loss_fn(labels, logits):\n      check_labels = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(labels, labels_input)), data=[labels])\n      
check_logits = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(logits, logits_input)), data=[logits])\n      with tf.control_dependencies([check_labels, check_logits]):\n        return tf.constant(loss)\n\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes=3, loss_fn=_loss_fn)\n\n    actual_training_loss = head.create_loss(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.EVAL,\n        logits=logits_input,\n        labels=labels_input)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(np.sum(loss), actual_training_loss.eval())\n\n  def test_eval_create_loss_loss_fn_wrong_shape(self):\n    \"\"\"Tests custom loss_fn that returns Tensor of unexpected shape.\"\"\"\n    loss = np.array([1., 2.], dtype=np.float32)\n\n    def _loss_fn(labels, logits):\n      del labels, logits  # Unused\n      return tf.constant(loss)\n\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes=3, loss_fn=_loss_fn)\n\n    logits = np.array([[-10., 10., 0.], [-15., 10., 0.]], dtype=np.float32)\n    labels = np.array([[1], [2]], dtype=np.int64)\n    actual_training_loss = head.create_loss(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[loss_fn must return Tensor of shape \\[D0, D1, ... DN, 1\\]\\. 
\\] '\n          r'\\[logits_shape: \\] \\[2 3\\] \\[loss_shape: \\] \\[2\\]'):\n        actual_training_loss.eval()\n\n  def test_eval_labels_none(self):\n    \"\"\"Tests that error is raised when labels is None.\"\"\"\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes=3)\n\n    with self.assertRaisesRegexp(\n        ValueError, r'You must provide a labels Tensor\\. Given: None\\.'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.EVAL,\n          logits=np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32),\n          labels=None)\n\n  def test_eval(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)\n    logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)\n    labels = np.array(((1,), (1,)), dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # loss = sum(cross_entropy(labels, logits)) = sum(10, 0) = 10.\n    expected_loss = 10.\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)\n\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: expected_loss / 2,\n        keys.ACCURACY: 0.5,  # 1 of 2 labels is correct.\n    }\n\n    # Assert spec contains expected tensors.\n    self.assertIsNotNone(spec.loss)\n    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())\n    self.assertIsNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    _assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, and metrics.\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k 
in spec.eval_metric_ops}\n      loss, metrics = sess.run((spec.loss, update_ops))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      # Check results of both update (in `metrics`) and value ops.\n      self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)\n      self.assertAllClose(\n          expected_metrics, {k: value_ops[k].eval() for k in value_ops},\n          rtol=tol,\n          atol=tol)\n\n  def test_eval_metric_ops_with_head_name(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes, name='some_multiclass_head')\n    logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)\n    labels = np.array(((1,), (1,)), dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)\n\n    expected_metric_keys = [\n        '{}/some_multiclass_head'.format(metric_keys.MetricKeys.LOSS_MEAN),\n        '{}/some_multiclass_head'.format(metric_keys.MetricKeys.ACCURACY)\n    ]\n    self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys())\n\n  def test_eval_with_regularization_losses(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes,\n        loss_reduction=tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE)\n    logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)\n    labels = np.array(((1,), (1,)), dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    regularization_losses = [1.5, 0.5]\n    expected_regularization_loss = 2.\n    # unregularized_loss = sum(cross_entropy(labels, logits)) / batch_size\n    #                    = sum(10, 0) / 2 = 5.\n    expected_unregularized_loss = 5.\n    expected_regularized_loss = (\n        expected_unregularized_loss + expected_regularization_loss)\n    # 
Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels,\n        regularization_losses=regularization_losses)\n\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: expected_unregularized_loss,\n        keys.LOSS_REGULARIZATION: expected_regularization_loss,\n        keys.ACCURACY: 0.5,  # 1 of 2 labels is correct.\n    }\n\n    # Assert predictions, loss, and metrics.\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, metrics = sess.run((spec.loss, update_ops))\n      self.assertAllClose(expected_regularized_loss, loss, rtol=tol, atol=tol)\n      # Check results of both update (in `metrics`) and value ops.\n      self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)\n      self.assertAllClose(\n          expected_metrics, {k: value_ops[k].eval() for k in value_ops},\n          rtol=tol,\n          atol=tol)\n\n  def test_eval_with_label_vocabulary_create_loss(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes, label_vocabulary=['aang', 'iroh', 'zuko'])\n    logits = [[10., 0, 0], [0, 10, 0]]\n    labels = [[b'iroh'], [b'iroh']]\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # loss = cross_entropy(labels, logits) = [10, 0].\n    expected_training_loss = 10.\n    training_loss = head.create_loss(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(\n          expected_training_loss, 
training_loss.eval(), rtol=1e-2, atol=1e-2)\n\n  def test_eval_with_label_vocabulary(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes, label_vocabulary=['aang', 'iroh', 'zuko'])\n\n    logits = [[10., 0, 0], [0, 10, 0]]\n    labels = [[b'iroh'], [b'iroh']]\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # loss = sum(cross_entropy(labels, logits)) = sum(10, 0) = 10.\n    expected_loss = 10.\n    spec = head.create_estimator_spec(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)\n\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: expected_loss / 2,\n        keys.ACCURACY: 0.5,  # 1 of 2 labels is correct.\n    }\n\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, metrics = sess.run((spec.loss, update_ops))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      # Check results of both update (in `metrics`) and value ops.\n      self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)\n      self.assertAllClose(\n          expected_metrics, {k: value_ops[k].eval() for k in value_ops},\n          rtol=tol,\n          atol=tol)\n\n  def test_weighted_multi_example_eval(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes, weight_column='label_weights')\n\n    # Create estimator spec.\n    logits = np.array(((10, 0, 0), (0, 10, 0), (0, 0, 10),), dtype=np.float32)\n    labels = np.array(((1,), (2,), (2,)), dtype=np.int64)\n    weights_3x1 = np.array(((1.,), (2.,), (3.,)), dtype=np.float64)\n    # loss = sum(cross_entropy(labels, logits) * [1, 2, 3])\n    #      = sum([10, 10, 0] * [1, 2, 3]) = 30\n    
expected_loss = 30.\n    spec = head.create_estimator_spec(\n        features={\n            'x': np.array(((42,),), dtype=np.int32),\n            'label_weights': weights_3x1,\n        },\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels)\n\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: expected_loss / np.sum(weights_3x1),\n        # Weighted accuracy is 1 * 3.0 / sum weights = 0.5\n        keys.ACCURACY: 0.5,\n    }\n\n    # Assert spec contains expected tensors.\n    self.assertIsNotNone(spec.loss)\n    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())\n    self.assertIsNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    _assert_no_hooks(self, spec)\n\n    # Assert loss, and metrics.\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, metrics = sess.run((spec.loss, update_ops))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      # Check results of both update (in `metrics`) and value ops.\n      self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)\n      self.assertAllClose(\n          expected_metrics, {k: value_ops[k].eval() for k in value_ops},\n          rtol=tol,\n          atol=tol)\n\n  def test_train_create_loss(self):\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes=3)\n\n    logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)\n    labels = np.array(((1,), (1,)), dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n\n    # unreduced_loss = cross_entropy(labels, logits) = [10, 0].\n    expected_unreduced_loss = [[10.], [0.]]\n    # Weights default to 1.\n    
expected_weights = 1.\n    # training_loss = 1 * 10 + 1 * 0\n    expected_training_loss = 10.\n    training_loss, unreduced_loss, actual_weights, _ = head.create_loss(\n        features=features, mode=ModeKeys.TRAIN, logits=logits, labels=labels)\n    tol = 1e-2\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(\n          expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)\n      self.assertAllClose(\n          expected_unreduced_loss, unreduced_loss.eval(), rtol=tol, atol=tol)\n      self.assertAllClose(expected_weights, actual_weights)\n\n  def test_train_create_loss_loss_reduction(self):\n    \"\"\"Tests create_loss with loss_reduction.\"\"\"\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes=3,\n        loss_reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)\n\n    logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)\n    labels = np.array(((1,), (1,)), dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n\n    # unreduced_loss = cross_entropy(labels, logits) = [10, 0].\n    expected_unreduced_loss = [[10.], [0.]]\n    # Weights default to 1.\n    expected_weights = 1.\n    # training_loss = 1 * 10 + 1 * 0 / num_nonzero_weights\n    expected_training_loss = 10. 
/ 2.\n    training_loss, unreduced_loss, actual_weights, _ = head.create_loss(\n        features=features, mode=ModeKeys.TRAIN, logits=logits, labels=labels)\n    tol = 1e-2\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(\n          expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)\n      self.assertAllClose(\n          expected_unreduced_loss, unreduced_loss.eval(), rtol=tol, atol=tol)\n      self.assertAllClose(expected_weights, actual_weights)\n\n  def test_train_labels_none(self):\n    \"\"\"Tests that error is raised when labels is None.\"\"\"\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes=3)\n\n    def _no_op_train_fn(loss):\n      del loss\n      return tf.no_op()\n\n    with self.assertRaisesRegexp(\n        ValueError, r'You must provide a labels Tensor\\. Given: None\\.'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32),\n          labels=None,\n          train_op_fn=_no_op_train_fn)\n\n  def test_train(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)\n\n    logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)\n    labels = np.array(((1,), (1,)), dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    expected_train_result = 'my_train_op'\n\n    def _train_op_fn(loss):\n      return tf.strings.join([\n          tf.constant(expected_train_result),\n          tf.strings.as_string(loss, precision=2)\n      ])\n\n    # loss = sum(cross_entropy(labels, logits)) = sum(10, 0) = 10.\n    expected_loss = 10.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        
train_op_fn=_train_op_fn)\n\n    self.assertIsNotNone(spec.loss)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    _assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, train_op, and summaries.\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),\n          train_result)\n      _assert_simple_summaries(\n          self, {\n              metric_keys.MetricKeys.LOSS: expected_loss,\n              metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,\n          }, summary_str, tol)\n\n  def test_train_with_optimizer(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)\n\n    logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)\n    labels = np.array(((1,), (1,)), dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    expected_train_result = 'my_train_op'\n\n    class _Optimizer(object):\n\n      def minimize(self, loss, global_step):\n        del global_step\n        return tf.strings.join([\n            tf.constant(expected_train_result),\n            tf.strings.as_string(loss, precision=2)\n        ])\n\n    # loss = sum(cross_entropy(labels, logits)) = sum(10, 0) = 10.\n    expected_loss = 10.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        optimizer=_Optimizer())\n\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      loss, 
train_result = sess.run((spec.loss, spec.train_op))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),\n          train_result)\n\n  def test_train_with_update_ops(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(n_classes)\n\n    with tf.Graph().as_default():\n      w = tf.Variable(1)\n      update_op = w.assign_add(1)\n      tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS,\n                                     update_op)\n\n      t = tf.Variable('')\n      expected_train_result = b'my_train_op'\n\n      def _train_op_fn(loss):\n        del loss\n        return t.assign(expected_train_result)\n\n      spec = head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32),\n          labels=np.array(((1,), (1,)), dtype=np.int64),\n          train_op_fn=_train_op_fn)\n\n      with self.cached_session() as sess:\n        _initialize_variables(self, spec.scaffold)\n        sess.run(spec.train_op)\n        w_value, t_value = sess.run([w, t])\n        self.assertEqual(2, w_value)\n        self.assertEqual(expected_train_result, t_value)\n\n  def test_train_summaries_with_head_name(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes, name='some_multiclass_head')\n\n    logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)\n    labels = np.array(((1,), (1,)), dtype=np.int64)\n    # loss = sum(cross_entropy(labels, logits)) = sum(10, 0) = 10.\n    expected_loss = 10.\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n\n    def _train_op_fn(loss):\n      del loss\n      return tf.no_op()\n\n    spec = head.create_estimator_spec(\n        features=features,\n        
mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn)\n\n    # Assert summaries.\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      summary_str = sess.run(spec.scaffold.summary_op)\n      _assert_simple_summaries(\n          self, {\n              '{}/some_multiclass_head'.format(metric_keys.MetricKeys.LOSS):\n                  expected_loss,\n              '{}/some_multiclass_head'.format(\n                  metric_keys.MetricKeys.LOSS_MEAN):\n                  expected_loss / 2,\n          }, summary_str, tol)\n\n  def test_train_with_regularization_losses(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes,\n        loss_reduction=tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE)\n\n    logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)\n    labels = np.array(((1,), (1,)), dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    expected_train_result = 'my_train_op'\n\n    def _train_op_fn(loss):\n      return tf.strings.join([\n          tf.constant(expected_train_result),\n          tf.strings.as_string(loss, precision=2)\n      ])\n\n    regularization_losses = [1.5, 0.5]\n    expected_regularization_loss = 2.\n    # unregularized_loss = sum(cross_entropy(labels, logits)) / batch_size\n    #                    = sum(10, 0) / 2 = 5.\n    # loss = unregularized_loss + regularization_loss = 7.\n    expected_loss = 7.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        regularization_losses=regularization_losses)\n\n    # Assert predictions, loss, train_op, and summaries.\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, 
spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),\n          train_result)\n      _assert_simple_summaries(\n          self, {\n              metric_keys.MetricKeys.LOSS:\n                  expected_loss,\n              metric_keys.MetricKeys.LOSS_REGULARIZATION:\n                  (expected_regularization_loss),\n          }, summary_str, tol)\n\n  def test_train_one_dim_create_loss(self):\n    \"\"\"Tests create_loss with 1D labels and weights (shape [batch_size]).\"\"\"\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes=3, weight_column='label_weights')\n\n    logits = np.array(((10, 0, 0), (0, 10, 0), (0, 0, 10),), dtype=np.float32)\n    labels_rank_1 = np.array((1, 2, 2,), dtype=np.int64)\n    weights_rank_1 = np.array((1., 2., 3.,), dtype=np.float64)\n    features = {\n        'x': np.array(((42,),), dtype=np.float32),\n        'label_weights': weights_rank_1\n    }\n\n    # unreduced_loss = cross_entropy(labels, logits) = [10, 10, 0].\n    expected_unreduced_loss = [[10.], [10.], [0.]]\n    # weights are reshaped to [3, 1] to match logits.\n    expected_weights = [[1.], [2.], [3.]]\n    # training_loss = 1 * 10 + 2 * 10 + 3 * 0 = 30.\n    expected_training_loss = 30.\n    training_loss, unreduced_loss, actual_weights, _ = head.create_loss(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels_rank_1)\n    tol = 1e-2\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(\n          expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)\n      self.assertAllClose(\n          
expected_unreduced_loss, unreduced_loss.eval(), rtol=tol, atol=tol)\n      self.assertAllClose(expected_weights, actual_weights.eval())\n\n  def test_train_one_dim(self):\n    \"\"\"Tests train with 1D labels and weights (shape [batch_size]).\"\"\"\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes=3, weight_column='label_weights')\n\n    logits = np.array(((10, 0, 0), (0, 10, 0), (0, 0, 10),), dtype=np.float32)\n    labels_rank_1 = np.array((1, 2, 2,), dtype=np.int64)\n    weights_rank_1 = np.array((1., 2., 3.,), dtype=np.float64)\n\n    self.assertEqual((3,), labels_rank_1.shape)\n    self.assertEqual((3,), weights_rank_1.shape)\n\n    expected_train_result = 'my_train_op'\n\n    def _train_op_fn(loss):\n      return tf.strings.join([\n          tf.constant(expected_train_result),\n          tf.strings.as_string(loss, precision=2)\n      ])\n\n    # loss = sum(cross_entropy(labels, logits) * [1, 2, 3])\n    #      = sum([10, 10, 0] * [1, 2, 3]) = 30\n    expected_loss = 30.\n\n    features = {\n        'x': np.array(((42,),), dtype=np.float32),\n        'label_weights': weights_rank_1\n    }\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels_rank_1,\n        train_op_fn=_train_op_fn)\n\n    self.assertIsNotNone(spec.loss)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    _assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, train_op, and summaries.\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          
six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),\n          train_result)\n      _assert_simple_summaries(\n          self, {\n              metric_keys.MetricKeys.LOSS: expected_loss,\n              metric_keys.MetricKeys.LOSS_MEAN:\n                  (expected_loss / np.sum(weights_rank_1)),\n          }, summary_str, tol)\n\n  def test_train_with_vocabulary_create_loss(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes, label_vocabulary=['aang', 'iroh', 'zuko'])\n\n    logits = [[10., 0, 0], [0, 10, 0]]\n    labels = [[b'iroh'], [b'iroh']]\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # loss = cross_entropy(labels, logits) = [10, 0].\n    expected_training_loss = 10.\n    training_loss = head.create_loss(\n        features=features, mode=ModeKeys.TRAIN, logits=logits, labels=labels)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(\n          expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)\n\n  def test_train_with_vocabulary(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes, label_vocabulary=['aang', 'iroh', 'zuko'])\n\n    logits = [[10., 0, 0], [0, 10, 0]]\n    labels = [[b'iroh'], [b'iroh']]\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n\n    def _train_op_fn(loss):\n      del loss\n      return tf.no_op()\n\n    # loss = sum(cross_entropy(labels, logits)) = sum(10, 0) = 10.\n    expected_loss = 10.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn)\n\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      loss = sess.run(spec.loss)\n      self.assertAllClose(expected_loss, loss, rtol=tol, 
atol=tol)\n\n  def test_weighted_multi_example_train(self):\n    n_classes = 3\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes, weight_column='label_weights')\n\n    # Create estimator spec.\n    logits = np.array(((10, 0, 0), (0, 10, 0), (0, 0, 10),), dtype=np.float32)\n    labels = np.array(((1,), (2,), (2,)), dtype=np.int64)\n    weights_3x1 = np.array(((1.,), (2.,), (3.,)), dtype=np.float64)\n    expected_train_result = 'my_train_op'\n    # loss = sum(cross_entropy(labels, logits) * [1, 2, 3])\n    #      = sum([10, 10, 0] * [1, 2, 3]) = 30\n    expected_loss = 30.\n\n    def _train_op_fn(loss):\n      return tf.strings.join([\n          tf.constant(expected_train_result),\n          tf.strings.as_string(loss, precision=2)\n      ])\n\n    spec = head.create_estimator_spec(\n        features={\n            'x': np.array(((42,),), dtype=np.float32),\n            'label_weights': weights_3x1,\n        },\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn)\n\n    self.assertIsNotNone(spec.loss)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    _assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, train_op, and summaries.\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),\n          train_result)\n      _assert_simple_summaries(\n          self,\n          {\n              metric_keys.MetricKeys.LOSS:\n                  expected_loss,\n              # loss mean = 
sum(cross_entropy(labels, logits) * [1,2,3]) / (1+2+3)\n              #      = sum([10, 10, 0] * [1, 2, 3]) / 6 = 30 / 6\n              metric_keys.MetricKeys.LOSS_MEAN:\n                  expected_loss / np.sum(weights_3x1),\n          },\n          summary_str,\n          tol)\n\n  def test_multi_dim_weighted_train_create_loss(self):\n    \"\"\"Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2].\"\"\"\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes=3, weight_column='weights')\n\n    logits = np.array([[[10, 0, 0], [12, 0, 0]], [[0, 10, 0], [0, 15, 0]]],\n                      dtype=np.float32)\n    labels = np.array([[[0], [1]], [[1], [2]]], dtype=np.int64)\n    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)\n\n    # unreduced_loss = cross_entropy(labels, logits) = [[0, 12], [0, 15]].\n    expected_unreduced_loss = [[[0.], [12.]], [[0.], [15.]]]\n    # weights are reshaped to [2, 2, 1] to match logits.\n    expected_weights = [[[1.], [1.5]], [[2.], [2.5]]]\n    # training_loss = 1*0 + 1.5*12 + 2*0 + 2.5*15 = 55.5\n    expected_training_loss = 55.5\n    training_loss, unreduced_loss, actual_weights, _ = head.create_loss(\n        features={'weights': weights},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels)\n    tol = 1e-2\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(\n          expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)\n      self.assertAllClose(\n          expected_unreduced_loss, unreduced_loss.eval(), rtol=tol, atol=tol)\n      self.assertAllClose(expected_weights, actual_weights.eval())\n\n  def test_multi_dim_weighted_train(self):\n    \"\"\"Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2].\"\"\"\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes=3, weight_column='weights')\n\n    logits = np.array([[[10, 0, 
0], [12, 0, 0]], [[0, 10, 0], [0, 15, 0]]],\n                      dtype=np.float32)\n    labels = np.array([[[0], [1]], [[1], [2]]], dtype=np.int64)\n    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)\n    expected_train_result = 'my_train_op'\n\n    def _train_op_fn(loss):\n      return tf.strings.join([\n          tf.constant(expected_train_result),\n          tf.strings.as_string(loss, precision=2)\n      ])\n\n    # loss = cross_entropy(labels, logits) = [[0, 12], [0, 15]].\n    # weighted_sum_loss = 1*0 + 1.5*12 + 2*0 + 2.5*15 = 55.5\n    expected_loss = 55.5\n    spec = head.create_estimator_spec(\n        features={'weights': weights},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn)\n\n    # Assert predictions, loss, train_op, and summaries.\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      loss, train_result = sess.run((spec.loss, spec.train_op))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),\n          train_result)\n\n  def test_multi_dim_train_weights_wrong_inner_dim(self):\n    \"\"\"Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 1].\"\"\"\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes=3, weight_column='weights')\n    logits = np.array([[[10, 0, 0], [12, 0, 0]], [[0, 10, 0], [0, 15, 0]]],\n                      dtype=np.float32)\n    labels = np.array([[[0], [1]], [[1], [2]]], dtype=np.int64)\n    weights = np.array([[1.], [2.]], dtype=np.float32)\n\n    def _no_op_train_fn(loss):\n      del loss\n      return tf.no_op()\n\n    spec = head.create_estimator_spec(\n        features={'weights': weights},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_no_op_train_fn)\n    with 
self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[logits_shape: \\] \\[2 2 3\\] \\[weights_shape: \\] \\[2 1\\]'):\n        spec.loss.eval()\n\n  def test_multi_dim_train_weights_wrong_outer_dim(self):\n    \"\"\"Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2, 3].\"\"\"\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes=3, weight_column='weights')\n    logits = np.array([[[10, 0, 0], [12, 0, 0]], [[0, 10, 0], [0, 15, 0]]],\n                      dtype=np.float32)\n    labels = np.array([[[0], [1]], [[1], [2]]], dtype=np.int64)\n    weights = np.array([[[1., 1.1, 1.2], [1.5, 1.6, 1.7]],\n                        [[2., 2.1, 2.2], [2.5, 2.6, 2.7]]])\n    weights_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n\n    def _no_op_train_fn(loss):\n      del loss\n      return tf.no_op()\n\n    spec = head.create_estimator_spec(\n        features={'weights': weights_placeholder},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_no_op_train_fn)\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[logits_shape: \\]\\s\\[2 2 3\\]\\s\\[weights_shape: \\]\\s\\[2 2 3\\]'):\n        spec.loss.eval({weights_placeholder: weights})\n\n  def test_multi_dim_weighted_eval(self):\n    \"\"\"Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2].\"\"\"\n    head = head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n        n_classes=3, weight_column='weights')\n    logits = np.array([[[10, 0, 0], [12, 0, 0]], [[0, 10, 0], [0, 15, 0]]],\n                      dtype=np.float32)\n    labels = np.array([[[0], [1]], [[1], [2]]], dtype=np.int64)\n    weights = np.array([[1., 1.5], [2., 2.5]], 
dtype=np.float32)\n    # loss = cross_entropy(labels, logits) = [[0, 12], [0, 15]].\n    # weighted_sum_loss = 1*0 + 1.5*12 + 2*0 + 2.5*15 = 55.5\n    expected_loss = 55.5\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features={'weights': weights},\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels)\n\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN:\n            expected_loss / np.sum(weights),\n        keys.ACCURACY:\n            (1. * 1. + 1.5 * 0. + 2. * 1. + 2.5 * 0.) / np.sum(weights),\n    }\n\n    # Assert predictions, loss, and metrics.\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, metrics = sess.run((spec.loss, update_ops))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      # Check results of both update (in `metrics`) and value ops.\n      self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)\n      self.assertAllClose(\n          expected_metrics, {k: value_ops[k].eval() for k in value_ops},\n          rtol=tol,\n          atol=tol)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(tf.test.TestCase):\n\n  def setUp(self):\n    tf.compat.v1.reset_default_graph()\n\n  def test_threshold_too_small(self):\n    with self.assertRaisesRegexp(ValueError, r'thresholds not in \\(0, 1\\)'):\n      head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n          thresholds=(0., 0.5))\n\n  def test_threshold_too_large(self):\n    with self.assertRaisesRegexp(ValueError, r'thresholds not in \\(0, 1\\)'):\n      head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n          thresholds=(0.5, 1.))\n\n  def 
test_invalid_loss_reduction(self):\n    with self.assertRaisesRegexp(\n        ValueError, r'Invalid loss_reduction: invalid_loss_reduction'):\n      head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n          loss_reduction='invalid_loss_reduction')\n    with self.assertRaisesRegexp(ValueError, r'Invalid loss_reduction: none'):\n      head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n          loss_reduction=tf.compat.v1.losses.Reduction.NONE)\n\n  def test_loss_fn_arg_labels_missing(self):\n\n    def _loss_fn(logits):\n      del logits  # Unused\n\n    with self.assertRaisesRegexp(\n        ValueError, r'loss_fn must contain argument: labels\\. '\n        r'Given arguments: \\(\\'logits\\',\\)'):\n      head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n          loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_logits_missing(self):\n\n    def _loss_fn(labels):\n      del labels  # unused\n\n    with self.assertRaisesRegexp(\n        ValueError, r'loss_fn must contain argument: logits\\. 
'\n        r'Given arguments: \\(\\'labels\\',\\)'):\n      head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n          loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_features_ok(self):\n\n    def _loss_fn(labels, logits, features):\n      del labels, logits, features  # Unused\n      head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n          loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_invalid(self):\n\n    def _loss_fn(labels, logits, name=None):\n      del labels, logits, name  # Unused\n\n    with self.assertRaisesRegexp(ValueError,\n                                 r'loss_fn has unexpected args: \\[\\'name\\'\\]'):\n      head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n          loss_fn=_loss_fn)\n\n  def test_invalid_logits_shape(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n    self.assertEqual(1, head.logits_dimension)\n\n    # Logits should be shape (batch_size, 1).\n    logits_2x2 = np.array(((45., 44.), (41., 42.),))\n\n    # Static shape.\n    with self.assertRaisesRegexp(ValueError, 'logits shape'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42.,),))},\n          mode=ModeKeys.PREDICT,\n          logits=logits_2x2)\n\n    # Dynamic shape.\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42.,),))},\n        mode=ModeKeys.PREDICT,\n        logits=logits_placeholder)\n    with self.cached_session():\n      with self.assertRaisesRegexp(tf.errors.OpError, 'logits shape'):\n        spec.predictions[prediction_keys.PredictionKeys.PROBABILITIES].eval(\n            {logits_placeholder: logits_2x2})\n\n  def test_invalid_labels_shape(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n    self.assertEqual(1, head.logits_dimension)\n\n    # Labels and logits should be shape (batch_size, 1).\n    labels_2x2 = 
np.array(((45., 44.), (41., 42.),))\n    logits_2x1 = np.array(((45.,), (41.,),))\n\n    # Static shape.\n    with self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):\n      head.create_loss(\n          features={'x': np.array(((42.,),))},\n          mode=ModeKeys.EVAL,\n          logits=logits_2x1,\n          labels=labels_2x2)\n\n    # Dynamic shape.\n    labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    training_loss = head.create_loss(\n        features={'x': np.array(((42.,),))},\n        mode=ModeKeys.EVAL,\n        logits=logits_placeholder,\n        labels=labels_placeholder)[0]\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[expected_labels_shape: \\] \\[2 1\\] \\[labels_shape: \\] \\[2 2\\]'):\n        training_loss.eval({\n            logits_placeholder: logits_2x1,\n            labels_placeholder: labels_2x2\n        })\n\n  def test_incompatible_labels_shape(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n    self.assertEqual(1, head.logits_dimension)\n\n    # Both logits and labels should be shape (batch_size, 1).\n    values_2x1 = np.array(((0.,), (1.,),))\n    values_3x1 = np.array(((0.,), (1.,), (0.,),))\n\n    # Static shape.\n    with self.assertRaisesRegexp(ValueError,\n                                 'logits and labels must have the same shape'):\n      head.create_loss(\n          features={'x': values_2x1},\n          mode=ModeKeys.EVAL,\n          logits=values_2x1,\n          labels=values_3x1)\n    with self.assertRaisesRegexp(ValueError,\n                                 'logits and labels must have the same shape'):\n      head.create_loss(\n          features={'x': values_2x1},\n          mode=ModeKeys.EVAL,\n          logits=values_3x1,\n          labels=values_2x1)\n\n    # Dynamic shape.\n    
labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    training_loss = head.create_loss(\n        features={'x': values_2x1},\n        mode=ModeKeys.EVAL,\n        logits=logits_placeholder,\n        labels=labels_placeholder)[0]\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[expected_labels_shape: \\] \\[3 1\\] \\[labels_shape: \\] \\[2 1\\]'):\n        training_loss.eval({\n            labels_placeholder: values_2x1,\n            logits_placeholder: values_3x1\n        })\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[expected_labels_shape: \\] \\[2 1\\] \\[labels_shape: \\] \\[3 1\\]'):\n        training_loss.eval({\n            labels_placeholder: values_3x1,\n            logits_placeholder: values_2x1\n        })\n\n  def test_name(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        name='foo')\n    self.assertEqual('foo', head.name)\n\n  def test_predict(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = [[0.3], [-0.4]]\n    expected_logistics = [[0.574443], [0.401312]]\n    expected_probabilities = [[0.425557, 0.574443], [0.598688, 0.401312]]\n    expected_class_ids = [[1], [0]]\n    expected_all_class_ids = [[0, 1]] * 2\n    expected_classes = [[b'1'], [b'0']]\n    expected_all_classes = [[b'0', b'1']] * 2\n    expected_export_classes = [[b'0', b'1']] * 2\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits)\n\n    # Assert spec contains expected tensors.\n    self.assertIsNone(spec.loss)\n    self.assertEqual({}, 
spec.eval_metric_ops)\n    self.assertIsNone(spec.train_op)\n    self.assertItemsEqual(\n        ('classification', 'regression', 'predict', _DEFAULT_SERVING_KEY),\n        spec.export_outputs.keys())\n    _assert_no_hooks(self, spec)\n\n    # Assert predictions.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      predictions = sess.run(spec.predictions)\n      self.assertAllClose(logits,\n                          predictions[prediction_keys.PredictionKeys.LOGITS])\n      self.assertAllClose(expected_logistics,\n                          predictions[prediction_keys.PredictionKeys.LOGISTIC])\n      self.assertAllClose(\n          expected_probabilities,\n          predictions[prediction_keys.PredictionKeys.PROBABILITIES])\n      self.assertAllClose(expected_class_ids,\n                          predictions[prediction_keys.PredictionKeys.CLASS_IDS])\n      self.assertAllEqual(expected_classes,\n                          predictions[prediction_keys.PredictionKeys.CLASSES])\n      self.assertAllClose(\n          expected_all_class_ids,\n          predictions[prediction_keys.PredictionKeys.ALL_CLASS_IDS])\n      self.assertAllEqual(\n          expected_all_classes,\n          predictions[prediction_keys.PredictionKeys.ALL_CLASSES])\n      self.assertAllClose(\n          expected_probabilities,\n          sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores))\n      self.assertAllEqual(\n          expected_export_classes,\n          sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].classes))\n      self.assertAllClose(expected_logistics,\n                          sess.run(spec.export_outputs['regression'].value))\n\n  def test_predict_with_vocabulary_list(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        label_vocabulary=['aang', 'iroh'])\n\n    logits = [[1.], [0.]]\n    expected_classes = [[b'iroh'], [b'aang']]\n\n    spec = 
head.create_estimator_spec(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits)\n\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertAllEqual(\n          expected_classes,\n          sess.run(spec.predictions[prediction_keys.PredictionKeys.CLASSES]))\n\n  def test_eval_create_loss(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n\n    # loss = cross_entropy(labels, logits) = [0, 41].\n    expected_training_loss = 41.\n    # Create loss.\n    training_loss = head.create_loss(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(\n          expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)\n\n  def test_eval_labels_none(self):\n    \"\"\"Tests that error is raised when labels is None.\"\"\"\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n\n    with self.assertRaisesRegexp(\n        ValueError, r'You must provide a labels Tensor\\. 
Given: None\\.'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.EVAL,\n          logits=np.array(((45,), (-41,),), dtype=np.float32),\n          labels=None)\n\n  def test_eval(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)\n\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        # loss = sum(cross_entropy(labels, logits)) = sum(0, 41) = 41\n        # loss_mean = loss/2 = 41./2 = 20.5\n        keys.LOSS_MEAN: 20.5,\n        keys.ACCURACY: 1. / 2,\n        keys.PRECISION: 1.,\n        keys.RECALL: 1. / 2,\n        keys.PREDICTION_MEAN: 1. / 2,\n        keys.LABEL_MEAN: 2. / 2,\n        keys.ACCURACY_BASELINE: 2. 
/ 2,\n        keys.AUC: 0.,\n        keys.AUC_PR: 1.,\n    }\n\n    # Assert spec contains expected tensors.\n    self.assertIsNotNone(spec.loss)\n    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())\n    self.assertIsNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    _assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, metrics = sess.run((spec.loss, update_ops))\n      self.assertAllClose(41., loss)\n      # Check results of both update (in `metrics`) and value ops.\n      self.assertAllClose(expected_metrics, metrics)\n      self.assertAllClose(expected_metrics,\n                          {k: value_ops[k].eval() for k in value_ops})\n\n  def test_eval_metric_ops_with_head_name(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        name='some_binary_head')\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)\n\n    expected_metric_keys = [\n        '{}/some_binary_head'.format(metric_keys.MetricKeys.LOSS_MEAN),\n        '{}/some_binary_head'.format(metric_keys.MetricKeys.ACCURACY),\n        '{}/some_binary_head'.format(metric_keys.MetricKeys.PRECISION),\n        '{}/some_binary_head'.format(metric_keys.MetricKeys.RECALL),\n        '{}/some_binary_head'.format(metric_keys.MetricKeys.PREDICTION_MEAN),\n        '{}/some_binary_head'.format(metric_keys.MetricKeys.LABEL_MEAN),\n        
'{}/some_binary_head'.format(metric_keys.MetricKeys.ACCURACY_BASELINE),\n        '{}/some_binary_head'.format(metric_keys.MetricKeys.AUC),\n        '{}/some_binary_head'.format(metric_keys.MetricKeys.AUC_PR),\n    ]\n    self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys())\n\n  def test_eval_with_regularization_losses(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        loss_reduction=tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE)\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    regularization_losses = [1.5, 0.5]\n    expected_regularization_loss = 2.\n    # unregularized_loss = sum(cross_entropy(labels, logits)) / batch_size\n    #                    = sum(0, 41) / 2 = 20.5\n    expected_unregularized_loss = 20.5\n    expected_regularized_loss = (\n        expected_unregularized_loss + expected_regularization_loss)\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels,\n        regularization_losses=regularization_losses)\n\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: expected_unregularized_loss,\n        keys.LOSS_REGULARIZATION: expected_regularization_loss,\n        keys.ACCURACY: 1. / 2,\n        keys.PRECISION: 1.,\n        keys.RECALL: 1. / 2,\n        keys.PREDICTION_MEAN: 1. / 2,\n        keys.LABEL_MEAN: 2. / 2,\n        keys.ACCURACY_BASELINE: 2. 
/ 2,\n        keys.AUC: 0.,\n        keys.AUC_PR: 1.,\n    }\n\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, metrics = sess.run((spec.loss, update_ops))\n      self.assertAllClose(expected_regularized_loss, loss)\n      # Check results of both update (in `metrics`) and value ops.\n      self.assertAllClose(expected_metrics, metrics)\n      self.assertAllClose(expected_metrics,\n                          {k: value_ops[k].eval() for k in value_ops})\n\n  def test_eval_with_vocabulary_list_create_loss(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        label_vocabulary=['aang', 'iroh'])\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = [[b'iroh'], [b'iroh']]\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # Create loss.\n    training_loss = head.create_loss(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(41., training_loss.eval())\n\n  def test_eval_with_vocabulary_list(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        label_vocabulary=['aang', 'iroh'])\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = [[b'iroh'], [b'iroh']]\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)\n\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      
self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      sess.run(update_ops)\n      self.assertAllClose(1. / 2,\n                          value_ops[metric_keys.MetricKeys.ACCURACY].eval())\n\n  def test_eval_with_thresholds_create_loss(self):\n    thresholds = [0.25, 0.5, 0.75]\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        thresholds=thresholds)\n    logits = np.array(((-1,), (1,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # probabilities[i] = 1/(1 + exp(-logits[i])) =>\n    # probabilities = [1/(1 + exp(1)), 1/(1 + exp(-1))] = [0.269, 0.731]\n    # loss = -ln(probabilities[label[i]])) = [-ln(0.269), -ln(0.731)]\n    #      = [1.31304389, 0.31334182]\n    # weighted sum loss = 1.62638571\n    expected_training_loss = 1.62638571\n    # Create loss.\n    training_loss = head.create_loss(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(\n          expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)\n\n  def test_eval_with_thresholds(self):\n    thresholds = [0.25, 0.5, 0.75]\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        thresholds=thresholds)\n    logits = np.array(((-1,), (1,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)\n\n    # probabilities[i] = 1/(1 + exp(-logits[i])) =>\n    # probabilities = [1/(1 + exp(1)), 1/(1 + exp(-1))] = 
[0.269, 0.731]\n    # loss = -sum(ln(probabilities[label[i]])) = -ln(0.269) -ln(0.731)\n    #      = 1.62652338\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: 1.62652338 / 2.,\n        keys.ACCURACY: 1. / 2,\n        keys.PRECISION: 1.,\n        keys.RECALL: .5,\n        keys.PREDICTION_MEAN: 1. / 2,\n        keys.LABEL_MEAN: 2. / 2,\n        keys.ACCURACY_BASELINE: 2. / 2,\n        keys.AUC: 0.,\n        keys.AUC_PR: 1.,\n        keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 1.,\n        keys.PRECISION_AT_THRESHOLD % thresholds[0]: 1.,\n        keys.RECALL_AT_THRESHOLD % thresholds[0]: 1.,\n        keys.ACCURACY_AT_THRESHOLD % thresholds[1]: .5,\n        keys.PRECISION_AT_THRESHOLD % thresholds[1]: 1.,\n        keys.RECALL_AT_THRESHOLD % thresholds[1]: .5,\n        keys.ACCURACY_AT_THRESHOLD % thresholds[2]: 0.,\n        keys.PRECISION_AT_THRESHOLD % thresholds[2]: 0.,\n        keys.RECALL_AT_THRESHOLD % thresholds[2]: 0.,\n    }\n    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, metrics = sess.run((spec.loss, update_ops))\n      self.assertAllClose(1.62652338, loss)\n      # Check results of both update (in `metrics`) and value ops.\n      self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)\n      self.assertAllClose(\n          expected_metrics, {k: value_ops[k].eval() for k in value_ops},\n          atol=tol,\n          rtol=tol)\n\n  def test_train_create_loss(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), 
dtype=np.float64)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # unreduced_loss = cross_entropy(labels, logits) = [0, 41]\n    expected_unreduced_loss = [[0.], [41.]]\n    # weights default to 1.\n    expected_weights = 1.\n    # training loss = 1 * 0 + 1 * 41\n    expected_training_loss = 41.\n    # Create loss.\n    training_loss, unreduced_loss, actual_weights, _ = head.create_loss(\n        features=features, mode=ModeKeys.TRAIN, logits=logits, labels=labels)\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(expected_training_loss, training_loss.eval())\n      self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())\n      self.assertAllClose(expected_weights, actual_weights)\n\n  def test_train_create_loss_loss_reduction(self):\n    \"\"\"Tests create_loss with loss_reduction.\"\"\"\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        loss_reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)\n\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.float64)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # unreduced_loss = cross_entropy(labels, logits) = [0, 41]\n    expected_unreduced_loss = [[0.], [41.]]\n    # weights default to 1.\n    expected_weights = 1.\n    # training loss = (1 * 0 + 1 * 41) / num_nonzero_weights\n    expected_training_loss = 41. 
/ 2.\n    # Create loss.\n    training_loss, unreduced_loss, actual_weights, _ = head.create_loss(\n        features=features, mode=ModeKeys.TRAIN, logits=logits, labels=labels)\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(expected_training_loss, training_loss.eval())\n      self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())\n      self.assertAllClose(expected_weights, actual_weights)\n\n  def test_eval_create_loss_loss_fn(self):\n    \"\"\"Tests head.create_loss for eval mode and custom loss_fn.\"\"\"\n    loss = np.array([[1.], [2.]], dtype=np.float32)\n    logits_input = np.array([[-10.], [10.]], dtype=np.float32)\n    labels_input = np.array([[1], [0]], dtype=np.int64)\n\n    def _loss_fn(labels, logits):\n      check_labels = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(labels, labels_input)), data=[labels])\n      check_logits = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(logits, logits_input)), data=[logits])\n      with tf.control_dependencies([check_labels, check_logits]):\n        return tf.constant(loss)\n\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        loss_fn=_loss_fn)\n\n    actual_training_loss = head.create_loss(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.EVAL,\n        logits=logits_input,\n        labels=labels_input)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(np.sum(loss), actual_training_loss.eval())\n\n  def test_eval_create_loss_loss_fn_wrong_shape(self):\n    \"\"\"Tests custom loss_fn that returns Tensor of unexpected shape.\"\"\"\n    loss = np.array([1., 2.], dtype=np.float32)\n\n    def _loss_fn(labels, logits):\n      del labels, logits  # Unused\n      return tf.constant(loss)\n\n    head = 
head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        loss_fn=_loss_fn)\n\n    logits = np.array([[-10.], [10.]], dtype=np.float32)\n    labels = np.array([[1], [0]], dtype=np.int64)\n    actual_training_loss = head.create_loss(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[loss_fn must return Tensor of shape \\[D0, D1, ... DN, 1\\]\\. \\] '\n          r'\\[logits_shape: \\] \\[2 1\\] \\[loss_shape: \\] \\[2\\]'):\n        actual_training_loss.eval()\n\n  def test_train_labels_none(self):\n    \"\"\"Tests that error is raised when labels is None.\"\"\"\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n\n    def _no_op_train_fn(loss):\n      del loss\n      return tf.no_op()\n\n    with self.assertRaisesRegexp(\n        ValueError, r'You must provide a labels Tensor\\. 
Given: None\\.'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array(((45,), (-41,),), dtype=np.float32),\n          labels=None,\n          train_op_fn=_no_op_train_fn)\n\n  def test_train(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.float64)\n    expected_train_result = b'my_train_op'\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # loss = sum(cross_entropy(labels, logits)) = sum(0, 41) = 41\n    expected_loss = 41.\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn)\n\n    # Assert spec contains expected tensors.\n    self.assertIsNotNone(spec.loss)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    _assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op))\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      _assert_simple_summaries(\n          self,\n          {\n              metric_keys.MetricKeys.LOSS: 
expected_loss,\n              # loss_mean = loss/2 = 41/2 = 20.5\n              metric_keys.MetricKeys.LOSS_MEAN: 20.5,\n          },\n          summary_str)\n\n  def test_train_with_optimizer(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.float64)\n    expected_train_result = b'my_train_op'\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # loss = sum(cross_entropy(labels, logits)) = sum(0, 41) = 41\n    expected_loss = 41.\n\n    class _Optimizer(object):\n\n      def minimize(self, loss, global_step):\n        del global_step\n        with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n            tf.cast(expected_loss, dtype=tf.dtypes.float32),\n            tf.cast(loss, dtype=tf.dtypes.float32),\n            name='assert_loss'),)):\n          return tf.constant(expected_train_result)\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        optimizer=_Optimizer())\n\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      loss, train_result = sess.run((spec.loss, spec.train_op))\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n\n  def test_train_with_update_ops(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n\n    with tf.Graph().as_default():\n      w = tf.Variable(1)\n      update_op = w.assign_add(1)\n      tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS,\n                                     update_op)\n\n      t = tf.Variable('')\n      expected_train_result = b'my_train_op'\n\n      def _train_op_fn(loss):\n        del loss\n        return t.assign(expected_train_result)\n\n      spec = 
head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array(((45,), (-41,),), dtype=np.float32),\n          labels=np.array(((1,), (1,),), dtype=np.float64),\n          train_op_fn=_train_op_fn)\n\n      with self.cached_session() as sess:\n        _initialize_variables(self, spec.scaffold)\n        sess.run(spec.train_op)\n        w_value, t_value = sess.run([w, t])\n        self.assertEqual(2, w_value)\n        self.assertEqual(expected_train_result, t_value)\n\n  def test_train_summaries_with_head_name(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        name='some_binary_head')\n\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.float64)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # loss = sum(cross_entropy(labels, logits)) = sum(0, 41) = 41\n    expected_loss = 41.\n\n    def _train_op_fn(loss):\n      del loss\n      return tf.no_op()\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn)\n    # Assert summaries.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      summary_str = sess.run(spec.scaffold.summary_op)\n      _assert_simple_summaries(\n          self,\n          {\n              '{}/some_binary_head'.format(metric_keys.MetricKeys.LOSS):\n                  expected_loss,\n              # loss_mean = loss/2 = 41/2 = 20.5\n              '{}/some_binary_head'.format(metric_keys.MetricKeys.LOSS_MEAN):\n                  20.5,\n          },\n          summary_str)\n\n  def test_train_with_regularization_losses(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        
loss_reduction=tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE)\n\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.float64)\n    expected_train_result = b'my_train_op'\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    regularization_losses = [1.5, 0.5]\n    expected_regularization_loss = 2.\n    # unregularized_loss = sum(cross_entropy(labels, logits)) / batch_size\n    #                    = sum(0, 41) / 2 = 20.5\n    # loss = unregularized_loss + regularization_loss = 7.\n    expected_loss = 22.5\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        regularization_losses=regularization_losses)\n\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op))\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      _assert_simple_summaries(\n          self, {\n              metric_keys.MetricKeys.LOSS:\n                  expected_loss,\n              metric_keys.MetricKeys.LOSS_REGULARIZATION:\n                  (expected_regularization_loss),\n          }, summary_str)\n\n  def test_float_labels_invalid_values(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n\n    logits = 
np.array([[0.5], [-0.3]], dtype=np.float32)\n    labels = np.array([[1.2], [0.4]], dtype=np.float32)\n    features = {'x': np.array([[42]], dtype=np.float32)}\n    with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,\n                                 r'Labels must <= n_classes - 1'):\n      training_loss = head.create_loss(\n          features=features, mode=ModeKeys.TRAIN, logits=logits,\n          labels=labels)[0]\n\n  def test_float_labels_train_create_loss(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n\n    logits = np.array([[0.5], [-0.3]], dtype=np.float32)\n    labels = np.array([[0.8], [0.4]], dtype=np.float32)\n    features = {'x': np.array([[42]], dtype=np.float32)}\n    # loss = cross_entropy(labels, logits)\n    #      = -label[i]*sigmoid(logit[i]) -(1-label[i])*sigmoid(-logit[i])\n    #      = [-0.8 * log(sigmoid(0.5)) -0.2 * log(sigmoid(-0.5)),\n    #         -0.4 * log(sigmoid(-0.3)) -0.6 * log(sigmoid(0.3))]\n    #      = [0.57407698418, 0.67435524446]\n    # weighted sum loss = 0.57407698418 + 0.67435524446\n    expected_training_loss = 1.24843222864\n    # Create loss.\n    training_loss = head.create_loss(\n        features=features, mode=ModeKeys.TRAIN, logits=logits, labels=labels)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(\n          expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)\n\n  def test_float_labels_train(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n\n    logits = np.array([[0.5], [-0.3]], dtype=np.float32)\n    labels = np.array([[0.8], [0.4]], dtype=np.float32)\n    expected_train_result = b'my_train_op'\n    features = {'x': np.array([[42]], dtype=np.float32)}\n    # loss = sum(cross_entropy(labels, logits))\n    #      = sum(-label[i]*sigmoid(logit[i]) -(1-label[i])*sigmoid(-logit[i]))\n    #      = -0.8 * log(sigmoid(0.5)) -0.2 * 
log(sigmoid(-0.5))\n    #        -0.4 * log(sigmoid(-0.3)) -0.6 * log(sigmoid(0.3))\n    #      = 1.2484322\n    expected_loss = 1.2484322\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((dnn_testing_utils_v1.assert_close(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32)),)):\n        return tf.constant(expected_train_result)\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn)\n\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      loss, train_result = sess.run((spec.loss, spec.train_op))\n      self.assertAlmostEqual(expected_loss, loss, delta=1.e-5)\n      self.assertEqual(expected_train_result, train_result)\n\n  def test_float_labels_eval_create_loss(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n\n    logits = np.array([[0.5], [-0.3]], dtype=np.float32)\n    labels = np.array([[0.8], [0.4]], dtype=np.float32)\n    features = {'x': np.array([[42]], dtype=np.float32)}\n    # loss = cross_entropy(labels, logits)\n    #      = -label[i]*sigmoid(logit[i]) -(1-label[i])*sigmoid(-logit[i])\n    #      = [-0.8 * log(sigmoid(0.5)) -0.2 * log(sigmoid(-0.5)),\n    #         -0.4 * log(sigmoid(-0.3)) -0.6 * log(sigmoid(0.3))]\n    #      = [0.57407698418, 0.67435524446]\n    # weighted sum loss = 0.57407698418 + 0.67435524446\n    expected_training_loss = 1.24843222864\n    # Create loss.\n    training_loss = head.create_loss(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(\n          expected_training_loss, training_loss.eval(), rtol=1e-2, 
atol=1e-2)\n\n  def test_float_labels_eval(self):\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss()\n\n    logits = np.array([[0.5], [-0.3]], dtype=np.float32)\n    labels = np.array([[0.8], [0.4]], dtype=np.float32)\n    features = {'x': np.array([[42]], dtype=np.float32)}\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)\n\n    # loss = sum(cross_entropy(labels, logits))\n    #      = sum(-label[i]*sigmoid(logit[i]) -(1-label[i])*sigmoid(-logit[i]))\n    #      = -0.8 * log(sigmoid(0.5)) -0.2 * log(sigmoid(-0.5))\n    #        -0.4 * log(sigmoid(-0.3)) -0.6 * log(sigmoid(0.3))\n    #      = 1.2484322\n    expected_loss = 1.2484322\n\n    # Assert loss.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, metrics = sess.run((spec.loss, update_ops))\n      self.assertAlmostEqual(expected_loss, loss, delta=1.e-5)\n      self.assertAlmostEqual(expected_loss / 2.,\n                             metrics[metric_keys.MetricKeys.LOSS_MEAN])\n\n  def test_weighted_multi_example_predict(self):\n    \"\"\"3 examples, 1 batch.\"\"\"\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        weight_column='label_weights')\n\n    # Create estimator spec.\n    logits = np.array(((45,), (-41,), (44,)), dtype=np.int32)\n    spec = head.create_estimator_spec(\n        features={\n            'x': np.array(((42,), (43,), (44,)), dtype=np.int32),\n            'label_weights': np.array(((1.,), (.1,), (1.5,)), dtype=np.float32),\n        },\n        mode=ModeKeys.PREDICT,\n        logits=logits)\n\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      predictions = 
sess.run(spec.predictions)\n      self.assertAllClose(\n          logits.astype(np.float32),\n          predictions[prediction_keys.PredictionKeys.LOGITS])\n      self.assertAllClose(\n          _sigmoid(logits).astype(np.float32),\n          predictions[prediction_keys.PredictionKeys.LOGISTIC])\n      self.assertAllClose(\n          [[0., 1.], [1., 0.], [0., 1.]],\n          predictions[prediction_keys.PredictionKeys.PROBABILITIES])\n      self.assertAllClose([[1], [0], [1]],\n                          predictions[prediction_keys.PredictionKeys.CLASS_IDS])\n      self.assertAllEqual([[b'1'], [b'0'], [b'1']],\n                          predictions[prediction_keys.PredictionKeys.CLASSES])\n\n  def test_weighted_multi_example_eval(self):\n    \"\"\"3 examples, 1 batch.\"\"\"\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        weight_column='label_weights')\n\n    # Create estimator spec.\n    logits = np.array(((45,), (-41,), (44,)), dtype=np.int32)\n    spec = head.create_estimator_spec(\n        features={\n            'x': np.array(((42,), (43,), (44,)), dtype=np.int32),\n            'label_weights': np.array(((1.,), (.1,), (1.5,)), dtype=np.float32),\n        },\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=np.array(((1,), (1,), (0,)), dtype=np.int32))\n\n    # label_mean = (1*1 + .1*1 + 1.5*0)/(1 + .1 + 1.5) = 1.1/2.6\n    #            = .42307692307\n    expected_label_mean = .42307692307\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        # losses = label_weights*cross_entropy(labels, logits)\n        #        = (1*0 + .1*41 + 1.5*44) = (1, 4.1, 66)\n        # loss = sum(losses) = 1 + 4.1 + 66 = 70.1\n        # loss_mean = loss/sum(label_weights) = 70.1/(1 + .1 + 1.5)\n        #           = 70.1/2.6 = 26.9615384615\n        keys.LOSS_MEAN: 26.9615384615,\n        # accuracy = (1*1 + .1*0 + 1.5*0)/(1 + .1 + 1.5) = 1/2.6 = .38461538461\n        keys.ACCURACY: .38461538461,\n        
keys.PRECISION: 1. / 2.5,\n        keys.RECALL: 1. / 1.1,\n        # prediction_mean = (1*1 + .1*0 + 1.5*1)/(1 + .1 + 1.5) = 2.5/2.6\n        #                 = .96153846153\n        keys.PREDICTION_MEAN: .96153846153,\n        keys.LABEL_MEAN: expected_label_mean,\n        keys.ACCURACY_BASELINE: 1 - expected_label_mean,\n        keys.AUC: .45454565,\n        keys.AUC_PR: .6737757325172424,\n    }\n\n    # Assert spec contains expected tensors.\n    self.assertIsNotNone(spec.loss)\n    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())\n\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, metrics = sess.run((spec.loss, update_ops))\n      self.assertAllClose(70.1, loss)\n      # Check results of both update (in `metrics`) and value ops.\n      self.assertAllClose(expected_metrics, metrics)\n      self.assertAllClose(expected_metrics,\n                          {k: value_ops[k].eval() for k in value_ops})\n\n  def test_train_one_dim_create_loss(self):\n    \"\"\"Tests create_loss with 1D labels and weights (shape [batch_size]).\"\"\"\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        weight_column='label_weights')\n\n    # Create estimator spec.\n    logits = np.array(((45,), (-41,), (44,)), dtype=np.float32)\n    labels_rank_1 = np.array((1., 1., 0.,))\n    weights_rank_1 = np.array(((1., .1, 1.5,)), dtype=np.float64)\n    features = {\n        'x': np.array(((42.,), (43.,), (44.,)), dtype=np.float32),\n        'label_weights': weights_rank_1,\n    }\n    # unreduced_loss = cross_entropy(labels, logits) = [0, 41, 44]\n    expected_unreduced_loss = [[0.], [41.], [44.]]\n    # weights are reshaped to [3, 1] to match logits.\n    expected_weights = 
[[1.], [.1], [1.5]]\n    # training loss = 1 * 0 + .1 * 41 + 1.5 * 44\n    expected_training_loss = 70.1\n    # Create loss.\n    training_loss, unreduced_loss, actual_weights, _ = head.create_loss(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels_rank_1)\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(\n          expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)\n      self.assertAllClose(\n          expected_unreduced_loss, unreduced_loss.eval(), rtol=1e-2, atol=1e-2)\n      self.assertAllClose(expected_weights, actual_weights.eval())\n\n  def test_train_one_dim(self):\n    \"\"\"Tests train with 1D labels and weights (shape [batch_size]).\"\"\"\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        weight_column='label_weights')\n\n    # Create estimator spec.\n    logits = np.array(((45,), (-41,), (44,)), dtype=np.float32)\n    labels_rank_1 = np.array((1., 1., 0.,))\n    weights_rank_1 = np.array(((1., .1, 1.5,)), dtype=np.float64)\n    self.assertEqual((3,), labels_rank_1.shape)\n    self.assertEqual((3,), weights_rank_1.shape)\n    features = {\n        'x': np.array(((42.,), (43.,), (44.,)), dtype=np.float32),\n        'label_weights': weights_rank_1,\n    }\n    expected_train_result = b'my_train_op'\n    # losses = label_weights*cross_entropy(labels, logits)\n    #        = (1*0 + .1*41 + 1.5*44) = (1, 4.1, 66)\n    # loss = sum(losses) = 1 + 4.1 + 66 = 70.1\n    expected_loss = 70.1\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    spec = head.create_estimator_spec(\n        features=features,\n        
mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels_rank_1,\n        train_op_fn=_train_op_fn)\n\n    # Assert spec contains expected tensors.\n    self.assertIsNotNone(spec.loss)\n    self.assertIsNotNone(spec.train_op)\n\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op))\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      _assert_simple_summaries(\n          self,\n          {\n              metric_keys.MetricKeys.LOSS: expected_loss,\n              # loss_mean = loss/sum(label_weights) = 70.1/(1 + .1 + 1.5)\n              #           = 70.1/2.6 = 26.9615384615\n              metric_keys.MetricKeys.LOSS_MEAN: 26.9615384615,\n          },\n          summary_str)\n\n  def test_weighted_multi_example_train(self):\n    \"\"\"3 examples, 1 batch.\"\"\"\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        weight_column='label_weights')\n\n    # Create estimator spec.\n    logits = np.array(((45,), (-41,), (44,)), dtype=np.float32)\n    expected_train_result = b'my_train_op'\n    # losses = label_weights*cross_entropy(labels, logits)\n    #        = (1*0 + .1*41 + 1.5*44) = (1, 4.1, 66)\n    # loss = sum(losses) = 1 + 4.1 + 66 = 70.1\n    expected_loss = 70.1\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    spec = head.create_estimator_spec(\n        features={\n            'x': np.array(((42.,), (43.,), (44.,)), dtype=np.float32),\n            
'label_weights': np.array(((1.,), (.1,), (1.5,)), dtype=np.float64),\n        },\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=np.array(((1.,), (1.,), (0.,))),\n        train_op_fn=_train_op_fn)\n\n    # Assert spec contains expected tensors.\n    self.assertIsNotNone(spec.loss)\n    self.assertIsNotNone(spec.train_op)\n\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op))\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      _assert_simple_summaries(\n          self,\n          {\n              metric_keys.MetricKeys.LOSS:\n                  expected_loss,\n              # loss_mean = loss/sum(label_weights) = 70.1/(1 + .1 + 1.5)\n              #           = 70.1/2.6 = 26.9615384615\n              metric_keys.MetricKeys.LOSS_MEAN:\n                  26.9615384615,\n          },\n          summary_str)\n\n  def test_multi_dim_weighted_train_create_loss(self):\n    \"\"\"Logits and labels of shape [2, 2, 1], weights [2, 2].\"\"\"\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        weight_column='weights')\n\n    logits = np.array([[[10], [-10]], [[12], [-12]]], dtype=np.float32)\n    labels = np.array([[[0], [0]], [[1], [1]]], dtype=np.float64)\n    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)\n    # unreduced_loss = cross_entropy(labels, logits) = [[10, 0], [0, 12]].\n    expected_unreduced_loss = [[[10.], [0.]], [[0.], [12.]]]\n    # Weights are reshaped to [2, 2, 1] to match logits.\n    expected_weights = [[[1.], [1.5]], [[2.], [2.5]]]\n    # training_loss = 1*10 + 1.5*0 + 2*0 + 2.5*12 = 40\n    expected_training_loss = 40.\n    # Create loss.\n    training_loss, unreduced_loss, 
actual_weights, _ = head.create_loss(\n        features={'weights': weights},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels)\n    tol = 1e-2\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(\n          expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)\n      self.assertAllClose(\n          expected_unreduced_loss, unreduced_loss.eval(), rtol=tol, atol=tol)\n      self.assertAllClose(expected_weights, actual_weights.eval())\n\n  def test_multi_dim_weighted_train(self):\n    \"\"\"Logits and labels of shape [2, 2, 1], weights [2, 2].\"\"\"\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        weight_column='weights')\n\n    logits = np.array([[[10], [-10]], [[12], [-12]]], dtype=np.float32)\n    labels = np.array([[[0], [0]], [[1], [1]]], dtype=np.float64)\n    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)\n    # loss = cross_entropy(labels, logits) = [[10, 0], [0, 12]].\n    # weighted_sum_loss = 1*10 + 1.5*0 + 2*0 + 2.5*12 = 40\n    expected_loss = 40.\n    expected_train_result = 'my_train_op'\n\n    def _train_op_fn(loss):\n      return tf.strings.join([\n          tf.constant(expected_train_result),\n          tf.strings.as_string(loss, precision=2)\n      ])\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features={'weights': weights},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn)\n\n    # Assert predictions, loss, train_op, and summaries.\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      loss, train_result = sess.run((spec.loss, spec.train_op))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),\n          
train_result)\n\n  def test_multi_dim_train_weights_wrong_inner_dim(self):\n    \"\"\"Logits and labels of shape [2, 2, 1], weights [2, 1].\"\"\"\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        weight_column='weights')\n\n    logits = np.array([[[10], [-10]], [[12], [-12]]], dtype=np.float32)\n    labels = np.array([[[0], [0]], [[1], [1]]], dtype=np.float64)\n    weights = np.array([[1.], [2.]], dtype=np.float32)\n\n    def _no_op_train_fn(loss):\n      del loss\n      return tf.no_op()\n\n    spec = head.create_estimator_spec(\n        features={'weights': weights},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_no_op_train_fn)\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[logits_shape: \\] \\[2 2 1\\] \\[weights_shape: \\] \\[2 1\\]'):\n        spec.loss.eval()\n\n  def test_multi_dim_train_weights_wrong_outer_dim(self):\n    \"\"\"Logits and labels of shape [2, 2, 1], weights [2, 2, 2].\"\"\"\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        weight_column='weights')\n\n    logits = np.array([[[10], [-10]], [[12], [-12]]], dtype=np.float32)\n    labels = np.array([[[0], [0]], [[1], [1]]], dtype=np.float64)\n    weights_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n\n    def _no_op_train_fn(loss):\n      del loss\n      return tf.no_op()\n\n    spec = head.create_estimator_spec(\n        features={'weights': weights_placeholder},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_no_op_train_fn)\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[logits_shape: \\]\\s\\[2 2 
1\\]\\s\\[weights_shape: \\]\\s\\[2 2 2\\]'):\n        spec.loss.eval({\n            weights_placeholder:\n                np.array([[[1., 1.1], [1.5, 1.6]], [[2., 2.1], [2.5, 2.6]]])\n        })\n\n  def test_multi_dim_weighted_eval(self):\n    \"\"\"Logits and labels of shape [2, 2, 1], weights [2, 2].\"\"\"\n    head = head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(\n        weight_column='weights')\n\n    logits = np.array([[[10], [-10]], [[12], [-12]]], dtype=np.float32)\n    labels = np.array([[[0], [0]], [[1], [1]]], dtype=np.float64)\n    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)\n    # loss = cross_entropy(labels, logits) = [[10, 0], [0, 12]].\n    # weighted_sum_loss = 1*10 + 1.5*0 + 2*0 + 2.5*12 = 40\n    expected_loss = 40.\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features={'weights': weights},\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels)\n\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: expected_loss / np.sum(weights),\n        keys.ACCURACY:\n          (1. * 0. + 1.5 * 1. + 2. * 1. + 2.5 * 0.) / np.sum(weights),\n        keys.PRECISION: 2.0 / 3.0,\n        keys.RECALL: 2.0 / 4.5,\n        keys.PREDICTION_MEAN:\n          (1. * 1 + 1.5 * 0 + 2. * 1 + 2.5 * 0) / np.sum(weights),\n        keys.LABEL_MEAN:\n          (1. * 0 + 1.5 * 0 + 2. * 1 + 2.5 * 1) / np.sum(weights),\n        keys.ACCURACY_BASELINE:\n          (1. * 0 + 1.5 * 0 + 2. 
* 1 + 2.5 * 1) / np.sum(weights),\n        # We cannot reliably calculate AUC with only 4 data points, but the\n        # values should not change because of backwards-compatibility.\n        keys.AUC: 0.5222,\n        keys.AUC_PR: 0.7341,\n    }\n\n    tol = 1e-2\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, metrics = sess.run((spec.loss, update_ops))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      # Check results of both update (in `metrics`) and value ops.\n      self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)\n      self.assertAllClose(\n          expected_metrics, {k: value_ops[k].eval() for k in value_ops},\n          rtol=tol,\n          atol=tol)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass RegressionHead(tf.test.TestCase):\n\n  def setUp(self):\n    tf.compat.v1.reset_default_graph()\n\n  def test_invalid_label_dimension(self):\n    with self.assertRaisesRegexp(ValueError, r'Invalid label_dimension'):\n      head_lib._regression_head(label_dimension=-1)\n    with self.assertRaisesRegexp(ValueError, r'Invalid label_dimension'):\n      head_lib._regression_head(label_dimension=0)\n\n  def test_invalid_loss_reduction(self):\n    with self.assertRaisesRegexp(\n        ValueError, r'Invalid loss_reduction: invalid_loss_reduction'):\n      head_lib._regression_head(loss_reduction='invalid_loss_reduction')\n    with self.assertRaisesRegexp(ValueError, r'Invalid loss_reduction: none'):\n      head_lib._regression_head(\n          loss_reduction=tf.compat.v1.losses.Reduction.NONE)\n\n  def test_loss_fn_arg_labels_missing(self):\n\n    def _loss_fn(logits):\n      del logits  # Unused\n\n    with self.assertRaisesRegexp(\n        ValueError, r'loss_fn must contain argument: labels\\. 
'\n        r'Given arguments: \\(\\'logits\\',\\)'):\n      head_lib._regression_head(loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_logits_missing(self):\n\n    def _loss_fn(labels):\n      del labels  # unused\n\n    with self.assertRaisesRegexp(\n        ValueError, r'loss_fn must contain argument: logits\\. '\n        r'Given arguments: \\(\\'labels\\',\\)'):\n      head_lib._regression_head(loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_features_ok(self):\n\n    def _loss_fn(labels, logits, features):\n      del labels, logits, features  # Unused\n      head_lib._regression_head(loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_invalid(self):\n\n    def _loss_fn(labels, logits, name=None):\n      del labels, logits, name  # Unused\n\n    with self.assertRaisesRegexp(ValueError,\n                                 r'loss_fn has unexpected args: \\[\\'name\\'\\]'):\n      head_lib._regression_head(loss_fn=_loss_fn)\n\n  def test_invalid_logits(self):\n    head = head_lib._regression_head(label_dimension=3)\n    self.assertEqual(3, head.logits_dimension)\n    logits_1d = np.array(((45.,), (41.,),))\n\n    # Static shape.\n    with self.assertRaisesRegexp(ValueError, 'logits shape'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42.,),))},\n          mode=ModeKeys.PREDICT,\n          logits=logits_1d)\n\n    # Dynamic shape.\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42.,),))},\n        mode=ModeKeys.PREDICT,\n        logits=logits_placeholder)\n    with self.cached_session():\n      with self.assertRaisesRegexp(tf.errors.OpError, 'logits shape'):\n        spec.predictions[prediction_keys.PredictionKeys.PREDICTIONS].eval(\n            {logits_placeholder: logits_1d})\n\n  def test_incompatible_labels_eval(self):\n    head = head_lib._regression_head(label_dimension=3)\n    self.assertEqual(3, head.logits_dimension)\n    values_3d = 
np.array(((45., 46., 47.), (41., 42., 43.),))\n    values_1d = np.array(((43.,), (44.,),))\n\n    # Static shape.\n    with self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):\n      head.create_loss(\n          features={'x': values_1d},\n          mode=ModeKeys.EVAL,\n          logits=values_3d,\n          labels=values_1d)\n    with self.assertRaisesRegexp(ValueError, 'logits shape'):\n      head.create_estimator_spec(\n          features={'x': values_3d},\n          labels=values_3d,\n          mode=ModeKeys.EVAL,\n          logits=values_1d,\n          train_op_fn=None)\n\n    # Dynamic shape.\n    labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    spec = head.create_estimator_spec(\n        features={'x': values_1d},\n        mode=ModeKeys.EVAL,\n        logits=logits_placeholder,\n        labels=labels_placeholder)\n    with self.cached_session():\n      with self.assertRaisesRegexp(tf.errors.OpError, 'logits shape'):\n        spec.loss.eval({\n            labels_placeholder: values_3d,\n            logits_placeholder: values_1d\n        })\n    training_loss = head.create_loss(\n        features={'x': values_1d},\n        mode=ModeKeys.EVAL,\n        logits=logits_placeholder,\n        labels=labels_placeholder)[0]\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[expected_labels_shape: \\] \\[2 3\\] \\[labels_shape: \\] \\[2 1\\]'):\n        training_loss.eval({\n            labels_placeholder: values_1d,\n            logits_placeholder: values_3d\n        })\n\n  def test_incompatible_labels_train(self):\n    head = head_lib._regression_head(label_dimension=3)\n    self.assertEqual(3, head.logits_dimension)\n    values_3d = np.array(((45., 46., 47.), (41., 42., 43.),))\n    values_1d = np.array(((43.,), (44.,),))\n\n    # Static shape.\n    with 
self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):\n      head.create_loss(\n          features={'x': values_1d},\n          mode=ModeKeys.TRAIN,\n          logits=values_3d,\n          labels=values_1d)\n\n    with self.assertRaisesRegexp(ValueError, 'logits shape'):\n      head.create_estimator_spec(\n          features={'x': values_3d},\n          mode=ModeKeys.TRAIN,\n          logits=values_1d,\n          labels=values_3d,\n          train_op_fn=lambda x: x)\n\n    # Dynamic shape.\n    labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    spec = head.create_estimator_spec(\n        features={'x': values_1d},\n        mode=ModeKeys.TRAIN,\n        logits=logits_placeholder,\n        labels=labels_placeholder,\n        train_op_fn=lambda x: x)\n    with self.cached_session():\n      with self.assertRaisesRegexp(tf.errors.OpError, 'logits shape'):\n        spec.loss.eval({\n            labels_placeholder: values_3d,\n            logits_placeholder: values_1d\n        })\n    training_loss = head.create_loss(\n        features={'x': values_1d},\n        mode=ModeKeys.TRAIN,\n        logits=logits_placeholder,\n        labels=labels_placeholder)[0]\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[expected_labels_shape: \\] \\[2 3\\] \\[labels_shape: \\] \\[2 1\\]'):\n        training_loss.eval({\n            labels_placeholder: values_1d,\n            logits_placeholder: values_3d\n        })\n\n  def test_name(self):\n    head = head_lib._regression_head(name='foo')\n    self.assertEqual('foo', head.name)\n\n  def test_predict(self):\n    head = head_lib._regression_head()\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = np.array(((45,), (41,),), dtype=np.int32)\n    spec = head.create_estimator_spec(\n        features={'x': 
np.array(((42.,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits)\n\n    # Assert spec contains expected tensors.\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    self.assertItemsEqual((prediction_key,), spec.predictions.keys())\n    self.assertEqual(tf.dtypes.float32, spec.predictions[prediction_key].dtype)\n    self.assertIsNone(spec.loss)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNone(spec.train_op)\n    default_serving_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n    self.assertItemsEqual((default_serving_key, 'predict', 'regression'),\n                          spec.export_outputs.keys())\n    _assert_no_hooks(self, spec)\n\n    # Assert predictions.\n    with self.cached_session():\n      _initialize_variables(self, spec.scaffold)\n      self.assertAllClose(logits, spec.predictions[prediction_key].eval())\n      self.assertAllClose(logits,\n                          spec.export_outputs[default_serving_key].value.eval())\n      self.assertAllClose(logits,\n                          spec.export_outputs['regression'].value.eval())\n      self.assertAllClose(\n          logits, spec.export_outputs['predict'].outputs['predictions'].eval())\n\n  def test_predict_with_inverse_link_fn(self):\n\n    def _inverse_link_fn(logits):\n      return logits - 10.\n\n    head = head_lib._regression_head(inverse_link_fn=_inverse_link_fn)\n\n    # Create estimator spec.\n    logits = np.array(((45,), (41,),), dtype=np.int32)\n    expected_predictions = np.array(((35,), (31,),), dtype=np.int32)\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42.,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits)\n\n    # Assert spec contains expected tensors.\n    keys = prediction_keys.PredictionKeys\n    self.assertItemsEqual((keys.PREDICTIONS, keys.LOGITS),\n                          spec.predictions.keys())\n    self.assertEqual(tf.dtypes.float32,\n        
             spec.predictions[keys.PREDICTIONS].dtype)\n    self.assertEqual(tf.dtypes.float32, spec.predictions[keys.LOGITS].dtype)\n    default_serving_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n    self.assertItemsEqual((default_serving_key, 'predict', 'regression'),\n                          spec.export_outputs.keys())\n\n    # Assert predictions.\n    with self.cached_session():\n      _initialize_variables(self, spec.scaffold)\n      self.assertAllClose(expected_predictions,\n                          spec.predictions[keys.PREDICTIONS].eval())\n      self.assertAllClose(logits, spec.predictions[keys.LOGITS].eval())\n      self.assertAllClose(expected_predictions,\n                          spec.export_outputs[default_serving_key].value.eval())\n      self.assertAllClose(expected_predictions,\n                          spec.export_outputs['regression'].value.eval())\n      self.assertAllClose(\n          expected_predictions,\n          spec.export_outputs['predict'].outputs['predictions'].eval())\n      self.assertAllClose(\n          logits, spec.export_outputs['predict'].outputs['logits'].eval())\n\n  def test_eval_create_loss(self):\n    head = head_lib._regression_head()\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43,), (44,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # Create loss.\n    training_loss = head.create_loss(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      # loss = [(43-45)^2, (44-41)^2] = [4, 9]\n      self.assertAllClose(13., training_loss.eval())\n\n  def test_eval_create_loss_loss_fn(self):\n    \"\"\"Tests head.create_loss for eval mode and custom loss_fn.\"\"\"\n    loss = np.array([[0., 1.], [2., 3.]], dtype=np.float32)\n    logits_input = np.array([[-1., 1.], [-2., 2.]], dtype=np.float32)\n    labels_input = 
np.array([[1., 0.], [2., -1.]], dtype=np.float32)\n\n    def _loss_fn(labels, logits):\n      check_labels = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(labels, labels_input)), data=[labels])\n      check_logits = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(logits, logits_input)), data=[logits])\n      with tf.control_dependencies([check_labels, check_logits]):\n        return tf.constant(loss)\n\n    head = head_lib._regression_head(label_dimension=2, loss_fn=_loss_fn)\n\n    actual_training_loss = head.create_loss(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.EVAL,\n        logits=logits_input,\n        labels=labels_input)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(np.sum(loss), actual_training_loss.eval())\n\n  def test_eval_create_loss_loss_fn_wrong_shape(self):\n    \"\"\"Tests custom loss_fn that returns Tensor of unexpected shape.\"\"\"\n    loss = np.array([[1.], [2.]], dtype=np.float32)\n\n    def _loss_fn(labels, logits):\n      del labels, logits  # Unused\n      return tf.constant(loss)\n\n    head = head_lib._regression_head(label_dimension=2, loss_fn=_loss_fn)\n\n    logits = np.array([[-1., 1.], [-2., 2.]], dtype=np.float32)\n    labels = np.array([[1., 0.], [2., -1.]], dtype=np.float32)\n    actual_training_loss = head.create_loss(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[loss_fn must return Tensor of shape \\[D0, D1, ... DN, 2\\]\\. 
\\] '\n          r'\\[logits_shape: \\] \\[2 2\\] \\[loss_shape: \\] \\[2 1\\]'):\n        actual_training_loss.eval()\n\n  def test_eval_labels_none(self):\n    \"\"\"Tests that error is raised when labels is None.\"\"\"\n    head = head_lib._regression_head()\n\n    with self.assertRaisesRegexp(\n        ValueError, r'You must provide a labels Tensor\\. Given: None\\.'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.EVAL,\n          logits=np.array(((45,), (41,),), dtype=np.float32),\n          labels=None)\n\n  def test_eval(self):\n    head = head_lib._regression_head()\n    self.assertEqual(1, head.logits_dimension)\n\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43,), (44,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)\n\n    # Assert spec contains expected tensors.\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    self.assertItemsEqual((prediction_key,), spec.predictions.keys())\n    self.assertEqual(tf.dtypes.float32, spec.predictions[prediction_key].dtype)\n    self.assertEqual(tf.dtypes.float32, spec.loss.dtype)\n    self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN,\n                           metric_keys.MetricKeys.PREDICTION_MEAN,\n                           metric_keys.MetricKeys.LABEL_MEAN),\n                          spec.eval_metric_ops.keys())\n    self.assertIsNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    _assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[\n          
metric_keys.MetricKeys.LOSS_MEAN]\n      predictions, loss, loss_mean = sess.run(\n          (spec.predictions[prediction_key], spec.loss, loss_mean_update_op))\n      self.assertAllClose(logits, predictions)\n      # loss = (43-45)^2 + (44-41)^2 = 4+9 = 13\n      self.assertAllClose(13., loss)\n      # loss_mean = loss/2 = 13/2 = 6.5\n      expected_loss_mean = 6.5\n      # Check results of both update (in `loss_mean`) and value ops.\n      self.assertAllClose(expected_loss_mean, loss_mean)\n      self.assertAllClose(expected_loss_mean, loss_mean_value_op.eval())\n\n  def test_eval_metric_ops_with_head_name_for_regression(self):\n    head = head_lib._regression_head(name='some_regression_head')\n    logits = np.array(((1,), (9,)), dtype=np.float32)\n    labels = np.array(((1,), (1,)), dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)\n\n    expected_metric_keys = [\n        '{}/some_regression_head'.format(metric_keys.MetricKeys.LOSS_MEAN),\n        '{}/some_regression_head'.format(\n            metric_keys.MetricKeys.PREDICTION_MEAN),\n        '{}/some_regression_head'.format(metric_keys.MetricKeys.LABEL_MEAN),\n    ]\n    self.assertItemsEqual(expected_metric_keys, spec.eval_metric_ops.keys())\n\n  def test_eval_with_regularization_losses(self):\n    head = head_lib._regression_head(\n        loss_reduction=tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE)\n    self.assertEqual(1, head.logits_dimension)\n\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43,), (44,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    regularization_losses = [1.5, 0.5]\n    expected_regularization_loss = 2.\n    # unregularized_loss = ((43-45)^2 + (44-41)^2) / batch_size\n    #                    = (4 + 9) / 2 = 6.5\n    
expected_unregularized_loss = 6.5\n    expected_regularized_loss = (\n        expected_unregularized_loss + expected_regularization_loss)\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels,\n        regularization_losses=regularization_losses)\n\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: expected_unregularized_loss,\n        keys.LOSS_REGULARIZATION: expected_regularization_loss,\n        keys.PREDICTION_MEAN: (45 + 41) / 2.0,\n        keys.LABEL_MEAN: (43 + 44) / 2.0,\n    }\n\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n      predictions, loss, metrics = sess.run(\n          (spec.predictions[prediction_key], spec.loss, update_ops))\n      self.assertAllClose(logits, predictions)\n      self.assertAllClose(expected_regularized_loss, loss)\n      # Check results of both update (in `metrics`) and value ops.\n      self.assertAllClose(expected_metrics, metrics)\n      self.assertAllClose(expected_metrics,\n                          {k: value_ops[k].eval() for k in value_ops})\n\n  def test_train_create_loss(self):\n    head = head_lib._regression_head()\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43,), (44,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # unreduced_loss = [(43-45)^2, (44-41)^2] = [4, 9]\n    expected_unreduced_loss = [[4.], [9.]]\n    # weights default to 1.\n    expected_weights = 1\n    # training_loss = 1 * 4 + 1 * 9 = 13\n    
expected_training_loss = 13.\n    # Create loss.\n    training_loss, unreduced_loss, actual_weights, _ = head.create_loss(\n        features=features, mode=ModeKeys.TRAIN, logits=logits, labels=labels)\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(expected_training_loss, training_loss.eval())\n      self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())\n      self.assertAllClose(expected_weights, actual_weights)\n\n  def test_train_create_loss_loss_reduction(self):\n    \"\"\"Tests create_loss with loss_reduction.\"\"\"\n    head = head_lib._regression_head(\n        loss_reduction=tf.compat.v1.losses.Reduction.SUM_BY_NONZERO_WEIGHTS)\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43,), (44,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # unreduced_loss = [(43-45)^2, (44-41)^2] = [4, 9]\n    expected_unreduced_loss = [[4.], [9.]]\n    # weights default to 1.\n    expected_weights = 1\n    # training_loss = (1 * 4 + 1 * 9) / num_nonzero_weights\n    expected_training_loss = 13. / 2.\n    # Create loss.\n    training_loss, unreduced_loss, actual_weights, _ = head.create_loss(\n        features=features, mode=ModeKeys.TRAIN, logits=logits, labels=labels)\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(expected_training_loss, training_loss.eval())\n      self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())\n      self.assertAllClose(expected_weights, actual_weights)\n\n  def test_train_labels_none(self):\n    \"\"\"Tests that error is raised when labels is None.\"\"\"\n    head = head_lib._regression_head()\n\n    def _no_op_train_fn(loss):\n      del loss\n      return tf.no_op()\n\n    with self.assertRaisesRegexp(\n        ValueError, r'You must provide a labels Tensor\\. 
Given: None\\.'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array(((45,), (41,),), dtype=np.float32),\n          labels=None,\n          train_op_fn=_no_op_train_fn)\n\n  def test_train(self):\n    head = head_lib._regression_head()\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43.,), (44.,),), dtype=np.float64)\n    expected_train_result = b'my_train_op'\n    features = {'x': np.array(((42.,),), dtype=np.float32)}\n    # loss = (43-45)^2 + (44-41)^2 = 4 + 9 = 13\n    expected_loss = 13\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn)\n\n    # Assert spec contains expected tensors.\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    self.assertItemsEqual((prediction_key,), spec.predictions.keys())\n    self.assertEqual(tf.dtypes.float32, spec.predictions[prediction_key].dtype)\n    self.assertEqual(tf.dtypes.float32, spec.loss.dtype)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    _assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      predictions, loss, train_result, summary_str = sess.run(\n          
(spec.predictions[prediction_key], spec.loss, spec.train_op,\n           spec.scaffold.summary_op))\n      self.assertAllClose(logits, predictions)\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      _assert_simple_summaries(\n          self,\n          {\n              metric_keys.MetricKeys.LOSS: expected_loss,\n              # loss_mean = loss/2 = 13/2 = 6.5\n              metric_keys.MetricKeys.LOSS_MEAN: 6.5,\n          },\n          summary_str)\n\n  def test_train_with_optimizer(self):\n    head = head_lib._regression_head()\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43.,), (44.,),), dtype=np.float64)\n    expected_train_result = b'my_train_op'\n    features = {'x': np.array(((42.,),), dtype=np.float32)}\n    # loss = (43-45)^2 + (44-41)^2 = 4 + 9 = 13\n    expected_loss = 13\n\n    class _Optimizer(object):\n\n      def minimize(self, loss, global_step):\n        del global_step\n        with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n            tf.cast(expected_loss, dtype=tf.dtypes.float32),\n            tf.cast(loss, dtype=tf.dtypes.float32),\n            name='assert_loss'),)):\n          return tf.constant(expected_train_result)\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        optimizer=_Optimizer())\n\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      loss, train_result = sess.run((spec.loss, spec.train_op))\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n\n  def test_train_with_update_ops(self):\n    head = head_lib._regression_head()\n\n    with tf.Graph().as_default():\n      w = tf.Variable(1)\n      update_op = w.assign_add(1)\n     
 tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.UPDATE_OPS,\n                                     update_op)\n\n      t = tf.Variable('')\n      expected_train_result = b'my_train_op'\n\n      def _train_op_fn(loss):\n        del loss\n        return t.assign(expected_train_result)\n\n      spec = head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array(((45,), (41,),), dtype=np.float32),\n          labels=np.array(((43.,), (44.,),), dtype=np.float64),\n          train_op_fn=_train_op_fn)\n\n      with self.cached_session() as sess:\n        _initialize_variables(self, spec.scaffold)\n        sess.run(spec.train_op)\n        w_value, t_value = sess.run([w, t])\n        self.assertEqual(2, w_value)\n        self.assertEqual(expected_train_result, t_value)\n\n  def test_train_summaries_with_head_name(self):\n    head = head_lib._regression_head(name='some_regression_head')\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43.,), (44.,),), dtype=np.float64)\n    features = {'x': np.array(((42.,),), dtype=np.float32)}\n    # loss = (43-45)^2 + (44-41)^2 = 4 + 9 = 13\n    expected_loss = 13\n\n    def _train_op_fn(loss):\n      del loss\n      return tf.no_op()\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn)\n\n    # Assert summaries.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      summary_str = sess.run(spec.scaffold.summary_op)\n      _assert_simple_summaries(\n          self,\n          {\n              '{}/some_regression_head'.format(metric_keys.MetricKeys.LOSS):\n                  expected_loss,\n              # 
loss_mean = loss/2 = 13/2 = 6.5\n              '{}/some_regression_head'\n              .format(metric_keys.MetricKeys.LOSS_MEAN):\n                  6.5,\n          },\n          summary_str)\n\n  def test_train_with_regularization_losses(self):\n    head = head_lib._regression_head(\n        loss_reduction=tf.compat.v1.losses.Reduction.SUM_OVER_BATCH_SIZE)\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43.,), (44.,),), dtype=np.float64)\n    expected_train_result = b'my_train_op'\n    features = {'x': np.array(((42.,),), dtype=np.float32)}\n    regularization_losses = [1.5, 0.5]\n    expected_regularization_loss = 2.\n    # unregularized_loss = ((43-45)^2 + (44-41)^2) / batch_size\n    #                    = (4 + 9) / 2 = 6.5\n    # loss = unregularized_loss + regularization_loss = 8.5\n    expected_loss = 8.5\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        regularization_losses=regularization_losses)\n\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n      predictions, loss, train_result, summary_str = sess.run(\n          (spec.predictions[prediction_key], spec.loss, spec.train_op,\n           spec.scaffold.summary_op))\n      self.assertAllClose(logits, predictions)\n      
self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      _assert_simple_summaries(\n          self, {\n              metric_keys.MetricKeys.LOSS: expected_loss,\n              metric_keys.MetricKeys.LOSS_REGULARIZATION:\n                  expected_regularization_loss,\n          }, summary_str)\n\n  def test_weighted_multi_example_eval(self):\n    \"\"\"1d label, 3 examples, 1 batch.\"\"\"\n    head = head_lib._regression_head(weight_column='label_weights')\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = np.array(((45,), (41,), (44,)), dtype=np.int32)\n    spec = head.create_estimator_spec(\n        features={\n            'x': np.array(((42,), (43,), (44,)), dtype=np.int32),\n            'label_weights': np.array(((1.,), (.1,), (1.5,)), dtype=np.float32),\n        },\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=np.array(((35,), (42,), (45,)), dtype=np.int32))\n\n    # Assert spec contains expected tensors.\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    self.assertItemsEqual((prediction_key,), spec.predictions.keys())\n    self.assertEqual(tf.dtypes.float32, spec.predictions[prediction_key].dtype)\n    self.assertEqual(tf.dtypes.float32, spec.loss.dtype)\n    self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN,\n                           metric_keys.MetricKeys.PREDICTION_MEAN,\n                           metric_keys.MetricKeys.LABEL_MEAN),\n                          spec.eval_metric_ops.keys())\n    self.assertIsNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    _assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[\n          metric_keys.MetricKeys.LOSS_MEAN]\n      
predictions, loss, loss_mean = sess.run(\n          (spec.predictions[prediction_key], spec.loss, loss_mean_update_op))\n      self.assertAllClose(logits, predictions)\n      # loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2 = 100+.1+1.5 = 101.6\n      self.assertAllClose(101.6, loss)\n      # loss_mean = loss/(1+.1+1.5) = 101.6/2.6 = 39.0769231\n      expected_loss_mean = 39.0769231\n      # Check results of both update (in `loss_mean`) and value ops.\n      self.assertAllClose(expected_loss_mean, loss_mean)\n      self.assertAllClose(expected_loss_mean, loss_mean_value_op.eval())\n\n  def test_weight_with_numeric_column(self):\n    \"\"\"1d label, 3 examples, 1 batch.\"\"\"\n    head = head_lib._regression_head(\n        weight_column=tf.feature_column.numeric_column(\n            'label_weights', normalizer_fn=lambda x: x + 1.))\n\n    # Create estimator spec.\n    logits = np.array(((45,), (41,), (44,)), dtype=np.int32)\n    spec = head.create_estimator_spec(\n        features={\n            'x':\n                np.array(((42,), (43,), (44,)), dtype=np.int32),\n            'label_weights':\n                np.array(((0.,), (-0.9,), (0.5,)), dtype=np.float32),\n        },\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=np.array(((35,), (42,), (45,)), dtype=np.int32))\n\n    # Assert loss.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      loss = sess.run(spec.loss)\n      # loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2 = 100+.1+1.5 = 101.6\n      self.assertAllClose(101.6, loss)\n\n  def test_weighted_multi_example_train(self):\n    \"\"\"1d label, 3 examples, 1 batch.\"\"\"\n    head = head_lib._regression_head(weight_column='label_weights')\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = np.array(((45,), (41,), (44,)), dtype=np.float32)\n    expected_train_result = b'my_train_op'\n    # loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2 = 
100+.1+1.5 = 101.6\n    expected_loss = 101.6\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    spec = head.create_estimator_spec(\n        features={\n            'x': np.array(((42,), (43,), (44,)), dtype=np.float32),\n            'label_weights': np.array(((1.,), (.1,), (1.5,)), dtype=np.float64),\n        },\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=np.array(((35.,), (42.,), (45.,)), dtype=np.float32),\n        train_op_fn=_train_op_fn)\n\n    # Assert spec contains expected tensors.\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    self.assertItemsEqual((prediction_key,), spec.predictions.keys())\n    self.assertEqual(tf.dtypes.float32, spec.predictions[prediction_key].dtype)\n    self.assertEqual(tf.dtypes.float32, spec.loss.dtype)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    _assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      predictions, loss, train_result, summary_str = sess.run(\n          (spec.predictions[prediction_key], spec.loss, spec.train_op,\n           spec.scaffold.summary_op))\n      self.assertAllClose(logits, predictions)\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      _assert_simple_summaries(\n          self,\n          {\n              metric_keys.MetricKeys.LOSS: expected_loss,\n              # loss_mean = loss/(1+.1+1.5) = 101.6/2.6 = 39.0769231\n              metric_keys.MetricKeys.LOSS_MEAN: 
39.0769231,\n          },\n          summary_str)\n\n  def test_train_one_dim_create_loss(self):\n    \"\"\"Tests create_loss with 1D labels and weights (shape [batch_size]).\"\"\"\n    head = head_lib._regression_head(weight_column='label_weights')\n    logits = np.array(((45,), (41,), (44,)), dtype=np.float32)\n    x_feature_rank_1 = np.array((42., 43., 44.,), dtype=np.float32)\n    weight_rank_1 = np.array((1., .1, 1.5,), dtype=np.float64)\n    labels_rank_1 = np.array((35., 42., 45.,))\n    # unreduced_loss = [(35-45)^2, (42-41)^2, (45-44)^2] = [100, 1, 1].\n    expected_unreduced_loss = [[100.], [1.], [1.]]\n    # weights are reshaped to [3, 1] to match logits.\n    expected_weights = [[1.], [.1], [1.5]]\n    # training_loss = 100 * 1 + 1 * .1 + 1.5 * 1 = 101.6\n    expected_training_loss = 101.6\n    features = {'x': x_feature_rank_1, 'label_weights': weight_rank_1}\n    # Create loss.\n    training_loss, unreduced_loss, actual_weights, _ = head.create_loss(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels_rank_1)\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(expected_training_loss, training_loss.eval())\n      self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())\n      self.assertAllClose(expected_weights, actual_weights.eval())\n\n  def test_train_one_dim(self):\n    \"\"\"Tests train with 1D labels and weights (shape [batch_size]).\"\"\"\n    head = head_lib._regression_head(weight_column='label_weights')\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = np.array(((45,), (41,), (44,)), dtype=np.float32)\n    expected_train_result = b'my_train_op'\n    # loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2 = 100+.1+1.5 = 101.6\n    expected_loss = 101.6\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          
tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    x_feature_rank_1 = np.array((42., 43., 44.,), dtype=np.float32)\n    weight_rank_1 = np.array((1., .1, 1.5,), dtype=np.float64)\n    labels_rank_1 = np.array((35., 42., 45.,))\n    features = {'x': x_feature_rank_1, 'label_weights': weight_rank_1}\n    self.assertEqual((3,), x_feature_rank_1.shape)\n    self.assertEqual((3,), weight_rank_1.shape)\n    self.assertEqual((3,), labels_rank_1.shape)\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels_rank_1,\n        train_op_fn=_train_op_fn)\n\n    # Assert spec contains expected tensors.\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    self.assertItemsEqual((prediction_key,), spec.predictions.keys())\n    self.assertEqual(tf.dtypes.float32, spec.predictions[prediction_key].dtype)\n    self.assertEqual(tf.dtypes.float32, spec.loss.dtype)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    _assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      predictions, loss, train_result, summary_str = sess.run(\n          (spec.predictions[prediction_key], spec.loss, spec.train_op,\n           spec.scaffold.summary_op))\n      self.assertAllClose(logits, predictions)\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      _assert_simple_summaries(\n          self,\n          {\n              metric_keys.MetricKeys.LOSS: expected_loss,\n              # loss_mean = loss/(1+.1+1.5) = 101.6/2.6 = 
39.0769231\n              metric_keys.MetricKeys.LOSS_MEAN: 39.0769231,\n          },\n          summary_str)\n\n  def test_weighted_multi_value_eval_create_loss(self):\n    \"\"\"3d label, 1 example, 1 batch.\"\"\"\n    head = head_lib._regression_head(\n        weight_column='label_weights', label_dimension=3)\n    logits = np.array(((45., 41., 44.),))\n    labels = np.array(((35., 42., 45.),))\n    features = {\n        'x': np.array(((42., 43., 44.),)),\n        'label_weights': np.array(((1., .1, 1.5),))\n    }\n    # Create loss.\n    training_loss = head.create_loss(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      # loss = [(35-45)^2, (42-41)^2, (45-44)^2] = [100, 1, 1].\n      # weighted sum loss = 1 * 100 + .1 * 1 + 1.5 * 1 = 101.6\n      self.assertAllClose(101.6, training_loss.eval())\n\n  def test_weighted_multi_value_eval(self):\n    \"\"\"3d label, 1 example, 1 batch.\"\"\"\n    head = head_lib._regression_head(\n        weight_column='label_weights', label_dimension=3)\n    self.assertEqual(3, head.logits_dimension)\n\n    logits = np.array(((45., 41., 44.),))\n    labels = np.array(((35., 42., 45.),))\n    features = {\n        'x': np.array(((42., 43., 44.),)),\n        'label_weights': np.array(((1., .1, 1.5),))\n    }\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)\n\n    # Assert spec contains expected tensors.\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    self.assertItemsEqual((prediction_key,), spec.predictions.keys())\n    self.assertEqual(tf.dtypes.float32, spec.predictions[prediction_key].dtype)\n    self.assertEqual(tf.dtypes.float32, spec.loss.dtype)\n    self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN,\n                           metric_keys.MetricKeys.PREDICTION_MEAN,\n  
                         metric_keys.MetricKeys.LABEL_MEAN),\n                          spec.eval_metric_ops.keys())\n    self.assertIsNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    _assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[\n          metric_keys.MetricKeys.LOSS_MEAN]\n      predictions, loss, loss_mean = sess.run(\n          (spec.predictions[prediction_key], spec.loss, loss_mean_update_op))\n      self.assertAllClose(logits, predictions)\n      # loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2 = 100+.1+1.5 = 101.6\n      self.assertAllClose(101.6, loss)\n      # loss_mean = loss/(1+.1+1.5) = 101.6/2.6 = 39.076923\n      expected_loss_mean = 39.076923\n      # Check results of both update (in `loss_mean`) and value ops.\n      self.assertAllClose(expected_loss_mean, loss_mean)\n      self.assertAllClose(expected_loss_mean, loss_mean_value_op.eval())\n\n  def test_weighted_multi_value_train_create_loss(self):\n    \"\"\"3d label, 1 example, 1 batch.\"\"\"\n    head = head_lib._regression_head(\n        weight_column='label_weights', label_dimension=3)\n    logits = np.array(((45., 41., 44.),))\n    labels = np.array(((35., 42., 45.),))\n    features = {\n        'x': np.array(((42., 43., 44.),)),\n        'label_weights': np.array(((1., .1, 1.5),))\n    }\n    # Create loss.\n    training_loss = head.create_loss(\n        features=features, mode=ModeKeys.TRAIN, logits=logits, labels=labels)[0]\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      # loss = [(35-45)^2, (42-41)^2, (45-44)^2] = [100, 1, 1].\n      # weighted sum loss = 1 * 100 + .1 * 1 + 1.5 * 1 = 101.6\n      self.assertAllClose(101.6, training_loss.eval())\n\n  def 
test_weighted_multi_value_train(self):\n    \"\"\"3d label, 1 example, 1 batch.\"\"\"\n    head = head_lib._regression_head(\n        weight_column='label_weights', label_dimension=3)\n    self.assertEqual(3, head.logits_dimension)\n\n    logits = np.array(((45., 41., 44.),))\n    labels = np.array(((35., 42., 45.),))\n    expected_train_result = b'my_train_op'\n    # loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2 = 100+.1+1.5 = 101.6\n    expected_loss = 101.6\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    features = {\n        'x': np.array(((42., 43., 44.),)),\n        'label_weights': np.array(((1., .1, 1.5),)),\n    }\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn)\n\n    # Assert spec contains expected tensors.\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    self.assertItemsEqual((prediction_key,), spec.predictions.keys())\n    self.assertEqual(tf.dtypes.float32, spec.predictions[prediction_key].dtype)\n    self.assertEqual(tf.dtypes.float32, spec.loss.dtype)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    _assert_no_hooks(self, spec)\n\n    # Evaluate predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      predictions, loss, train_result, summary_str = sess.run(\n          (spec.predictions[prediction_key], spec.loss, spec.train_op,\n           spec.scaffold.summary_op))\n      
self.assertAllClose(logits, predictions)\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      _assert_simple_summaries(\n          self,\n          {\n              metric_keys.MetricKeys.LOSS: expected_loss,\n              # loss_mean = loss/(1+.1+1.5) = 101.6/2.6 = 39.076923\n              metric_keys.MetricKeys.LOSS_MEAN: 39.076923,\n          },\n          summary_str)\n\n  def test_weighted_multi_batch_eval(self):\n    \"\"\"1d label, 1 example, 3 batches.\"\"\"\n    head = head_lib._regression_head(weight_column='label_weights')\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = np.array(((45.,), (41.,), (44.,)))\n    input_fn = numpy_io.numpy_input_fn(\n        x={\n            'x': np.array(((42.,), (43.,), (44.,))),\n            'label_weights': np.array(((1.,), (.1,), (1.5,))),\n            # 'logits' is not a feature, but we use `numpy_input_fn` to make a\n            # batched version of it, and pop it off before passing to\n            # `create_estimator_spec`.\n            'logits': logits,\n        },\n        y=np.array(((35.,), (42.,), (45.,))),\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    batched_features, batched_labels = input_fn()\n    batched_logits = batched_features.pop('logits')\n    spec = head.create_estimator_spec(\n        features=batched_features,\n        mode=ModeKeys.EVAL,\n        logits=batched_logits,\n        labels=batched_labels,\n        train_op_fn=None)\n\n    # losses = [1*(35-45)^2, .1*(42-41)^2, 1.5*(45-44)^2] = [100, .1, 1.5]\n    # loss = sum(losses) = 100+.1+1.5 = 101.6\n    # loss_mean = loss/(1+.1+1.5) = 101.6/2.6 = 39.076923\n    expected_metrics = {\n        metric_keys.MetricKeys.LOSS_MEAN:\n            39.076923,\n        metric_keys.MetricKeys.PREDICTION_MEAN:\n            (45 + 41 * 0.1 + 44 * 1.5) / 2.6,\n        metric_keys.MetricKeys.LABEL_MEAN: (35 + 42 * 0.1 + 45 * 1.5) 
/ 2.6,\n    }\n\n    # Assert spec contains expected tensors.\n    self.assertEqual(tf.dtypes.float32, spec.loss.dtype)\n    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())\n    self.assertIsNone(spec.train_op)\n    _assert_no_hooks(self, spec)\n\n    with self.cached_session() as sess:\n      # Finalize graph and initialize variables.\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      tf.compat.v1.train.queue_runner.start_queue_runners()\n\n      # Run tensors for `steps` steps.\n      steps = len(logits)\n      results = tuple([\n          sess.run((\n              spec.loss,\n              # The `[1]` gives us the metric update op.\n              {k: spec.eval_metric_ops[k][1]\n               for k in spec.eval_metric_ops}))\n          for _ in range(steps)\n      ])\n\n      # Assert losses and metrics.\n      self.assertAllClose((100, .1, 1.5), [r[0] for r in results])\n      # For metrics, check results of both update (in `results`) and value ops.\n      # Note: we only check the result of the last step for streaming metrics.\n      self.assertAllClose(expected_metrics, results[steps - 1][1])\n      self.assertAllClose(\n          expected_metrics,\n          {k: spec.eval_metric_ops[k][0].eval() for k in spec.eval_metric_ops})\n\n  def test_weighted_multi_batch_train(self):\n    \"\"\"1d label, 1 example, 3 batches.\"\"\"\n    head = head_lib._regression_head(weight_column='label_weights')\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = np.array(((45.,), (41.,), (44.,)))\n    input_fn = numpy_io.numpy_input_fn(\n        x={\n            'x': np.array(((42.,), (43.,), (44.,))),\n            'label_weights': np.array(((1.,), (.1,), (1.5,))),\n            # 'logits' is not a feature, but we use `numpy_input_fn` to make a\n            # batched version of it, and pop it off before passing to\n            # 
`create_estimator_spec`.\n            'logits': logits,\n        },\n        y=np.array(((35.,), (42.,), (45.,))),\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    batched_features, batched_labels = input_fn()\n    batched_logits = batched_features.pop('logits')\n    spec = head.create_estimator_spec(\n        features=batched_features,\n        mode=ModeKeys.TRAIN,\n        logits=batched_logits,\n        labels=batched_labels,\n        train_op_fn=lambda loss: loss * -7.)\n\n    # Assert spec contains expected tensors.\n    self.assertEqual(tf.dtypes.float32, spec.loss.dtype)\n    self.assertIsNotNone(spec.train_op)\n\n    with self.cached_session() as sess:\n      # Finalize graph and initialize variables.\n      _initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      tf.compat.v1.train.queue_runner.start_queue_runners()\n\n      results = tuple(\n          [sess.run((spec.loss, spec.train_op)) for _ in range(len(logits))])\n\n      # losses = [1*(35-45)^2, .1*(42-41)^2, 1.5*(45-44)^2] = [100, .1, 1.5]\n      expected_losses = np.array((100, .1, 1.5))\n      self.assertAllClose(expected_losses, [r[0] for r in results])\n      self.assertAllClose(expected_losses * -7., [r[1] for r in results])\n\n  def test_multi_dim_weighted_train_create_loss(self):\n    \"\"\"Logits, labels of shape [2, 2, 3], weight shape [2, 2].\"\"\"\n    label_dimension = 3\n    head = head_lib._regression_head(\n        weight_column='label_weights', label_dimension=label_dimension)\n    logits = np.array([[[00., 01., 02.], [10., 11., 12.]],\n                       [[20., 21., 22.], [30., 31., 32.]]])\n    labels = np.array([[[01., 02., 03.], [12., 13., 14.]],\n                       [[23., 24., 25.], [34., 35., 36.]]])\n    weights = np.array([[1., 1.5], [2., 2.5]])\n    expected_unreduced_loss = [[[1., 1., 1.], [4., 4., 4.]],\n                               [[9., 9., 9.], [16., 16., 16.]]]\n    
expected_training_loss = np.sum(\n        np.array([[[1. * x for x in [1., 1., 1.]],\n                   [1.5 * x for x in [4., 4., 4.]]],\n                  [[2. * x for x in [9., 9., 9.]],\n                   [2.5 * x for x in [16., 16., 16.]]]]))\n    # Weights are expanded to [2, 2, 1] to match logits.\n    expected_weights = [[[1.], [1.5]], [[2.], [2.5]]]\n    # Create loss.\n    training_loss, unreduced_loss, actual_weights, _ = head.create_loss(\n        features={'label_weights': weights},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels)\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(expected_training_loss, training_loss.eval())\n      self.assertAllClose(expected_unreduced_loss, unreduced_loss.eval())\n      self.assertAllClose(expected_weights, actual_weights.eval())\n\n  def test_multi_dim_weighted_train(self):\n    \"\"\"Logits, labels of shape [2, 2, 3], weight shape [2, 2].\"\"\"\n    head = head_lib._regression_head(\n        weight_column='label_weights', label_dimension=3)\n    logits = np.array([[[00., 01., 02.], [10., 11., 12.]],\n                       [[20., 21., 22.], [30., 31., 32.]]])\n    labels = np.array([[[01., 02., 03.], [12., 13., 14.]],\n                       [[23., 24., 25.], [34., 35., 36.]]])\n    expected_train_result = b'my_train_op'\n    features = {\n        'label_weights': np.array([[1., 1.5], [2., 2.5]]),\n    }\n    # loss = 1*3*1^2 + 1.5*3*2^2 + 2*3*3^2 +2.5*3*4^2 = 195\n    expected_loss = 195.\n\n    # Create estimator spec.\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    spec = head.create_estimator_spec(\n        features=features,\n        
mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn)\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(expected_loss, spec.loss.eval())\n\n  def test_multi_dim_train_weights_wrong_inner_dim(self):\n    \"\"\"Logits, labels of shape [2, 2, 3], weight shape [2, 1].\"\"\"\n    head = head_lib._regression_head(\n        weight_column='label_weights', label_dimension=3)\n    logits = np.array([[[00., 01., 02.], [10., 11., 12.]],\n                       [[20., 21., 22.], [30., 31., 32.]]])\n    labels = np.array([[[01., 02., 03.], [12., 13., 14.]],\n                       [[23., 24., 25.], [34., 35., 36.]]])\n    features = {\n        'label_weights': np.array([[1.], [2]]),\n    }\n\n    def _no_op_train_fn(loss):\n      del loss\n      return tf.no_op()\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_no_op_train_fn)\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[logits_shape: \\] \\[2 2 3\\] \\[weights_shape: \\] \\[2 1\\]'):\n        spec.loss.eval()\n\n  def test_multi_dim_train_weights_wrong_outer_dim(self):\n    \"\"\"Logits, labels of shape [2, 2, 3], weight shape [2, 2, 2].\"\"\"\n    head = head_lib._regression_head(\n        weight_column='label_weights', label_dimension=3)\n    logits = np.array([[[00., 01., 02.], [10., 11., 12.]],\n                       [[20., 21., 22.], [30., 31., 32.]]])\n    labels = np.array([[[01., 02., 03.], [12., 13., 14.]],\n                       [[23., 24., 25.], [34., 35., 36.]]])\n    weights_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    features = {\n        'label_weights': weights_placeholder,\n    }\n\n    def 
_no_op_train_fn(loss):\n      del loss\n      return tf.no_op()\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_no_op_train_fn)\n    with self.cached_session():\n      _initialize_variables(self, tf.compat.v1.train.Scaffold())\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[logits_shape: \\]\\s\\[2 2 3\\]\\s\\[weights_shape: \\]\\s\\[2 2 2\\]'):\n        spec.loss.eval({\n            weights_placeholder:\n                np.array([[[1., 1.1], [1.5, 1.6]], [[2., 2.1], [2.5, 2.6]]])\n        })\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/kmeans.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"A canned Estimator for k-means clustering.\"\"\"\n\n# TODO(ccolby): Move clustering_ops.py into this file and streamline the code.\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport time\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.ops import clustering_ops\nfrom tensorflow.python.ops import control_flow_ops\nfrom tensorflow_estimator.python.estimator import estimator\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.export import export_output\n\n\nclass _LossRelativeChangeHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Stops when the change in loss goes below a tolerance.\"\"\"\n\n  def __init__(self, loss_tensor, tolerance):\n    \"\"\"Creates a _LossRelativeChangeHook.\n\n    Args:\n      loss_tensor: A scalar tensor of the loss value.\n      tolerance: A relative tolerance of loss change between iterations.\n    \"\"\"\n    self._loss_tensor = loss_tensor\n    self._tolerance = tolerance\n    self._prev_loss = None\n\n  def before_run(self, 
run_context):\n    del run_context  # unused\n    return tf.compat.v1.train.SessionRunArgs(self._loss_tensor)\n\n  def after_run(self, run_context, run_values):\n    loss = run_values.results\n    assert loss is not None\n    if self._prev_loss:\n      relative_change = (\n          abs(loss - self._prev_loss) / (1 + abs(self._prev_loss)))\n      if relative_change < self._tolerance:\n        run_context.request_stop()\n    self._prev_loss = loss\n\n\nclass _InitializeClustersHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Initializes the cluster centers.\n\n  The chief repeatedly invokes an initialization op until all cluster centers\n  are initialized. The workers wait for the initialization phase to complete.\n  \"\"\"\n\n  def __init__(self, init_op, is_initialized_var, is_chief):\n    \"\"\"Creates an _InitializeClustersHook.\n\n    Args:\n      init_op: An op that, when run, will choose some initial cluster centers.\n        This op may need to be run multiple times to choose all the centers.\n      is_initialized_var: A boolean variable reporting whether all initial\n        centers have been chosen.\n      is_chief: A boolean specifying whether this task is the chief.\n    \"\"\"\n    self._init_op = init_op\n    self._is_initialized_var = is_initialized_var\n    self._is_chief = is_chief\n\n  def after_create_session(self, session, coord):\n    del coord  # unused\n    assert self._init_op.graph is tf.compat.v1.get_default_graph()\n    assert self._is_initialized_var.graph is self._init_op.graph\n    while True:\n      try:\n        if session.run(self._is_initialized_var):\n          break\n        elif self._is_chief:\n          session.run(self._init_op)\n        else:\n          time.sleep(1)\n      except RuntimeError as e:\n        tf.compat.v1.logging.info(e)\n\n\ndef _parse_features_if_necessary(features, feature_columns):\n  \"\"\"Helper function to convert the input points into a usable format.\n\n  Args:\n    features: The input features.\n    
feature_columns: An optional iterable containing all the feature columns\n      used by the model. All items in the set should be feature column instances\n      that can be passed to `tf.feature_column.input_layer`. If this is None,\n      all features will be used.\n\n  Returns:\n    If `features` is a dict of `k` features (optionally filtered by\n    `feature_columns`), each of which is a vector of `n` scalars, the return\n    value is a Tensor of shape `(n, k)` representing `n` input points, where the\n    items in the `k` dimension are sorted lexicographically by `features` key.\n    If `features` is not a dict, it is returned unmodified.\n  \"\"\"\n  if not isinstance(features, dict):\n    return features\n\n  if feature_columns:\n    return tf.compat.v1.feature_column.input_layer(features, feature_columns)\n\n  keys = sorted(features.keys())\n  with ops.colocate_with(features[keys[0]]):\n    return tf.concat([features[k] for k in keys], axis=1)\n\n\nclass _ModelFn(object):\n  \"\"\"Model function for the estimator.\"\"\"\n\n  def __init__(self, num_clusters, initial_clusters, distance_metric, seed,\n               use_mini_batch, mini_batch_steps_per_iteration,\n               kmeans_plus_plus_num_retries, relative_tolerance,\n               feature_columns):\n    self._num_clusters = num_clusters\n    self._initial_clusters = initial_clusters\n    self._distance_metric = distance_metric\n    self._seed = seed\n    self._use_mini_batch = use_mini_batch\n    self._mini_batch_steps_per_iteration = mini_batch_steps_per_iteration\n    self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries\n    self._relative_tolerance = relative_tolerance\n    self._feature_columns = feature_columns\n\n  def model_fn(self, features, mode, config):\n    \"\"\"Model function for the estimator.\n\n    Note that this does not take a `labels` arg. 
This works, but `input_fn` must\n    return either `features` or, equivalently, `(features, None)`.\n\n    Args:\n      features: The input points. See `tf.estimator.Estimator`.\n      mode: See `tf.estimator.Estimator`.\n      config: See `tf.estimator.Estimator`.\n\n    Returns:\n      A `tf.estimator.EstimatorSpec` (see `tf.estimator.Estimator`) specifying\n      this behavior:\n        * `train_op`: Execute one mini-batch or full-batch run of Lloyd's\n             algorithm.\n        * `loss`: The sum of the squared distances from each input point to its\n             closest center.\n        * `eval_metric_ops`: Maps `SCORE` to `loss`.\n        * `predictions`: Maps `ALL_DISTANCES` to the distance from each input\n             point to each cluster center; maps `CLUSTER_INDEX` to the index of\n             the closest cluster center for each input point.\n    \"\"\"\n    # input_points is a single Tensor. Therefore, the sharding functionality\n    # in clustering_ops is unused, and some of the values below are lists of a\n    # single item.\n    input_points = _parse_features_if_necessary(features, self._feature_columns)\n\n    # Let N = the number of input_points.\n    # all_distances: A list of one matrix of shape (N, num_clusters). Each value\n    #   is the distance from an input point to a cluster center.\n    # model_predictions: A list of one vector of shape (N). Each value is the\n    #   cluster id of an input point.\n    # losses: Similar to cluster_idx but provides the distance to the cluster\n    #   center.\n    # is_initialized: scalar indicating whether the initial cluster centers\n    #   have been chosen; see init_op.\n    # init_op: an op to choose the initial cluster centers. A single worker\n    #   repeatedly executes init_op until is_initialized becomes True.\n    # training_op: an op that runs an iteration of training, either an entire\n    #   Lloyd iteration or a mini-batch of a Lloyd iteration. 
Multiple workers\n    #   may execute this op, but only after is_initialized becomes True.\n    (all_distances, model_predictions, losses, is_initialized, init_op,\n     training_op) = clustering_ops.KMeans(\n         inputs=input_points,\n         num_clusters=self._num_clusters,\n         initial_clusters=self._initial_clusters,\n         distance_metric=self._distance_metric,\n         use_mini_batch=self._use_mini_batch,\n         mini_batch_steps_per_iteration=self._mini_batch_steps_per_iteration,\n         random_seed=self._seed,\n         kmeans_plus_plus_num_retries=self._kmeans_plus_plus_num_retries\n     ).training_graph()\n\n    loss = tf.math.reduce_sum(losses)\n    tf.compat.v1.summary.scalar('loss/raw', loss)\n\n    incr_step = tf.compat.v1.assign_add(tf.compat.v1.train.get_global_step(), 1)\n    training_op = control_flow_ops.with_dependencies([training_op, incr_step],\n                                                     loss)\n\n    training_hooks = [\n        _InitializeClustersHook(init_op, is_initialized, config.is_chief)\n    ]\n    if self._relative_tolerance is not None:\n      training_hooks.append(\n          _LossRelativeChangeHook(loss, self._relative_tolerance))\n\n    export_outputs = {\n        KMeansClustering.ALL_DISTANCES:\n            export_output.PredictOutput(all_distances[0]),\n        KMeansClustering.CLUSTER_INDEX:\n            export_output.PredictOutput(model_predictions[0]),\n        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:\n            export_output.PredictOutput(model_predictions[0])\n    }\n\n    return model_fn_lib.EstimatorSpec(\n        mode=mode,\n        predictions={\n            KMeansClustering.ALL_DISTANCES: all_distances[0],\n            KMeansClustering.CLUSTER_INDEX: model_predictions[0],\n        },\n        loss=loss,\n        train_op=training_op,\n        eval_metric_ops={\n            KMeansClustering.SCORE: tf.compat.v1.metrics.mean(loss)\n        },\n        training_hooks=training_hooks,\n   
     export_outputs=export_outputs)\n\n\n# TODO(agarwal,ands): support sharded input.\n@estimator_export(v1=['estimator.experimental.KMeans'])\nclass KMeansClustering(estimator.Estimator):\n  \"\"\"An Estimator for K-Means clustering.\n\n  Example:\n  ```\n  import numpy as np\n  import tensorflow as tf\n\n  num_points = 100\n  dimensions = 2\n  points = np.random.uniform(0, 1000, [num_points, dimensions])\n\n  def input_fn():\n    return tf.compat.v1.train.limit_epochs(\n        tf.convert_to_tensor(points, dtype=tf.float32), num_epochs=1)\n\n  num_clusters = 5\n  kmeans = tf.compat.v1.estimator.experimental.KMeans(\n      num_clusters=num_clusters, use_mini_batch=False)\n\n  # train\n  num_iterations = 10\n  previous_centers = None\n  for _ in xrange(num_iterations):\n    kmeans.train(input_fn)\n    cluster_centers = kmeans.cluster_centers()\n    if previous_centers is not None:\n      print 'delta:', cluster_centers - previous_centers\n    previous_centers = cluster_centers\n    print 'score:', kmeans.score(input_fn)\n  print 'cluster centers:', cluster_centers\n\n  # map the input points to their clusters\n  cluster_indices = list(kmeans.predict_cluster_index(input_fn))\n  for i, point in enumerate(points):\n    cluster_index = cluster_indices[i]\n    center = cluster_centers[cluster_index]\n    print 'point:', point, 'is in cluster', cluster_index, 'centered at', center\n  ```\n\n  The `SavedModel` saved by the `export_saved_model` method does not include the\n  cluster centers. However, the cluster centers may be retrieved by the\n  latest checkpoint saved during training. 
Specifically,\n  ```\n  kmeans.cluster_centers()\n  ```\n  is equivalent to\n  ```\n  tf.train.load_variable(\n      kmeans.model_dir, KMeansClustering.CLUSTER_CENTERS_VAR_NAME)\n  ```\n  \"\"\"\n\n  # Valid values for the distance_metric constructor argument.\n  SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE\n  COSINE_DISTANCE = clustering_ops.COSINE_DISTANCE\n\n  # Values for initial_clusters constructor argument.\n  RANDOM_INIT = clustering_ops.RANDOM_INIT\n  KMEANS_PLUS_PLUS_INIT = clustering_ops.KMEANS_PLUS_PLUS_INIT\n\n  # Metric returned by evaluate(): The sum of the squared distances from each\n  # input point to its closest center.\n  SCORE = 'score'\n\n  # Keys returned by predict().\n  # ALL_DISTANCES: The distance from each input point to each cluster center.\n  # CLUSTER_INDEX: The index of the closest cluster center for each input point.\n  CLUSTER_INDEX = 'cluster_index'\n  ALL_DISTANCES = 'all_distances'\n\n  # Variable name used by cluster_centers().\n  CLUSTER_CENTERS_VAR_NAME = clustering_ops.CLUSTERS_VAR_NAME\n\n  def __init__(self,\n               num_clusters,\n               model_dir=None,\n               initial_clusters=RANDOM_INIT,\n               distance_metric=SQUARED_EUCLIDEAN_DISTANCE,\n               seed=None,\n               use_mini_batch=True,\n               mini_batch_steps_per_iteration=1,\n               kmeans_plus_plus_num_retries=2,\n               relative_tolerance=None,\n               config=None,\n               feature_columns=None):\n    r\"\"\"Creates an Estimator for running KMeans training and inference.\n\n    This Estimator implements the following variants of the K-means algorithm:\n\n    If `use_mini_batch` is False, it runs standard full batch K-means. Each\n    training step runs a single iteration of K-Means and must process the full\n    input at once. 
To run in this mode, the `input_fn` passed to `train` must\n    return the entire input dataset.\n\n    If `use_mini_batch` is True, it runs a generalization of the mini-batch\n    K-means algorithm. It runs multiple iterations, where each iteration is\n    composed of `mini_batch_steps_per_iteration` steps. Each training step\n    accumulates the contribution from one mini-batch into temporary storage.\n    Every `mini_batch_steps_per_iteration` steps, the cluster centers are\n    updated and the temporary storage cleared for the next iteration.\n    For example: the entire dataset contains 64k examples, where the batch size\n    is 64. User can choose mini_batch_steps_per_iteration = 100 to run 10% of\n    the entire data every iteration in order to update the cluster centers.\n    Note that:\n      * If `mini_batch_steps_per_iteration=1`, the algorithm reduces to the\n        standard K-means mini-batch algorithm.\n      * If `mini_batch_steps_per_iteration = num_inputs / batch_size`, the\n        algorithm becomes an asynchronous version of the full-batch algorithm.\n        However, there is no guarantee by this implementation that each input\n        is seen exactly once per iteration. Also, different updates are applied\n        asynchronously without locking. So this asynchronous version may not\n        behave exactly like a full-batch version.\n\n    Args:\n      num_clusters: An integer tensor specifying the number of clusters. This\n        argument is ignored if `initial_clusters` is a tensor or numpy array.\n      model_dir: The directory to save the model results and log files.\n      initial_clusters: Specifies how the initial cluster centers are chosen.\n        One of the following: * a tensor or numpy array with the initial cluster\n          centers. * a callable `f(inputs, k)` that selects and returns up to\n          `k` centers from an input batch. `f` is free to return any number of\n          centers from `0` to `k`. 
It will be invoked on successive input\n          batches as necessary until all `num_clusters` centers are chosen.\n        * `KMeansClustering.RANDOM_INIT`: Choose centers randomly from an input\n          batch. If the batch size is less than `num_clusters` then the entire\n          batch is chosen to be initial cluster centers and the remaining\n          centers are chosen from successive input batches.\n        * `KMeansClustering.KMEANS_PLUS_PLUS_INIT`: Use kmeans++ to choose\n          centers from the first input batch. If the batch size is less than\n          `num_clusters`, a TensorFlow runtime error occurs.\n      distance_metric: The distance metric used for clustering. One of:\n        * `KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE`: Euclidean distance\n          between vectors `u` and `v` is defined as \\\\(||u - v||_2\\\\) which is\n          the square root of the sum of the absolute squares of the elements'\n          difference.\n        * `KMeansClustering.COSINE_DISTANCE`: Cosine distance between vectors\n          `u` and `v` is defined as \\\\(1 - (u . v) / (||u||_2 ||v||_2)\\\\).\n      seed: Python integer. Seed for PRNG used to initialize centers.\n      use_mini_batch: A boolean specifying whether to use the mini-batch k-means\n        algorithm. See explanation above.\n      mini_batch_steps_per_iteration: The number of steps after which the\n        updated cluster centers are synced back to a master copy. Used only if\n        `use_mini_batch=True`. See explanation above.\n      kmeans_plus_plus_num_retries: For each point that is sampled during\n        kmeans++ initialization, this parameter specifies the number of\n        additional points to draw from the current distribution before selecting\n        the best. If a negative value is specified, a heuristic is used to\n        sample `O(log(num_to_sample))` additional points. 
Used only if\n        `initial_clusters=KMeansClustering.KMEANS_PLUS_PLUS_INIT`.\n      relative_tolerance: A relative tolerance of change in the loss between\n        iterations. Stops learning if the loss changes less than this amount.\n        This may not work correctly if `use_mini_batch=True`.\n      config: See `tf.estimator.Estimator`.\n      feature_columns: An optional iterable containing all the feature columns\n        used by the model. All items in the set should be feature column\n        instances that can be passed to `tf.feature_column.input_layer`. If this\n        is None, all features will be used.\n\n    Raises:\n      ValueError: An invalid argument was passed to `initial_clusters` or\n        `distance_metric`.\n    \"\"\"\n    if isinstance(initial_clusters, str) and initial_clusters not in [\n        KMeansClustering.RANDOM_INIT, KMeansClustering.KMEANS_PLUS_PLUS_INIT\n    ]:\n      raise ValueError(\"Unsupported initialization algorithm '%s'\" %\n                       initial_clusters)\n    if distance_metric not in [\n        KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE,\n        KMeansClustering.COSINE_DISTANCE\n    ]:\n      raise ValueError(\"Unsupported distance metric '%s'\" % distance_metric)\n    self._distance_metric = distance_metric\n    super(KMeansClustering, self).__init__(\n        model_fn=_ModelFn(num_clusters, initial_clusters, distance_metric, seed,\n                          use_mini_batch, mini_batch_steps_per_iteration,\n                          kmeans_plus_plus_num_retries, relative_tolerance,\n                          feature_columns).model_fn,\n        model_dir=model_dir,\n        config=config)\n\n  def _predict_one_key(self, input_fn, predict_key):\n    for result in self.predict(input_fn=input_fn, predict_keys=[predict_key]):\n      yield result[predict_key]\n\n  def predict_cluster_index(self, input_fn):\n    \"\"\"Finds the index of the closest cluster center to each input point.\n\n    Args:\n      
input_fn: Input points. See `tf.estimator.Estimator.predict`.\n\n    Yields:\n      The index of the closest cluster center for each input point.\n    \"\"\"\n    for index in self._predict_one_key(input_fn,\n                                       KMeansClustering.CLUSTER_INDEX):\n      yield index\n\n  def score(self, input_fn):\n    \"\"\"Returns the sum of squared distances to nearest clusters.\n\n    Note that this function is different from the corresponding one in sklearn\n    which returns the negative sum.\n\n    Args:\n      input_fn: Input points. See `tf.estimator.Estimator.evaluate`. Only one\n        batch is retrieved.\n\n    Returns:\n      The sum of the squared distance from each point in the first batch of\n      inputs to its nearest cluster center.\n    \"\"\"\n    return self.evaluate(input_fn=input_fn, steps=1)[KMeansClustering.SCORE]\n\n  def transform(self, input_fn):\n    \"\"\"Transforms each input point to its distances to all cluster centers.\n\n    Note that if `distance_metric=KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE`,\n    this\n    function returns the squared Euclidean distance while the corresponding\n    sklearn function returns the Euclidean distance.\n\n    Args:\n      input_fn: Input points. See `tf.estimator.Estimator.predict`.\n\n    Yields:\n      The distances from each input point to each cluster center.\n    \"\"\"\n    for distances in self._predict_one_key(input_fn,\n                                           KMeansClustering.ALL_DISTANCES):\n      if self._distance_metric == KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE:\n        yield np.sqrt(distances)\n      else:\n        yield distances\n\n  def cluster_centers(self):\n    \"\"\"Returns the cluster centers.\"\"\"\n    return self.get_variable_value(KMeansClustering.CLUSTER_CENTERS_VAR_NAME)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/kmeans_test.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for KMeans.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\nimport time\n\nimport numpy as np\nfrom sklearn.cluster import KMeans as SklearnKMeans\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow.python.platform import benchmark\nfrom tensorflow.python.platform import flags\nfrom tensorflow_estimator.python.estimator import run_config\nfrom tensorflow_estimator.python.estimator.canned import kmeans as kmeans_lib\n\nFLAGS = flags.FLAGS\n\n\ndef normalize(x):\n  return x / np.sqrt(np.sum(x * x, axis=-1, keepdims=True))\n\n\ndef cosine_similarity(x, y):\n  return np.dot(normalize(x), np.transpose(normalize(y)))\n\n\ndef make_random_centers(num_centers, num_dims, center_norm=500):\n  return np.round(\n      np.random.rand(num_centers, num_dims).astype(np.float32) * center_norm)\n\n\ndef make_random_points(centers, num_points, max_offset=20):\n  num_centers, num_dims = centers.shape\n  assignments = np.random.choice(num_centers, num_points)\n  offsets = np.round(\n      np.random.randn(num_points, num_dims).astype(np.float32) * max_offset)\n  return (centers[assignments] + offsets, assignments,\n          
np.add.reduce(offsets * offsets, 1))\n\n\nclass KMeansTestBase(tf.test.TestCase):\n\n  def input_fn(self,\n               batch_size=None,\n               points=None,\n               randomize=None,\n               num_epochs=None):\n    \"\"\"Returns an input_fn that randomly selects batches from given points.\"\"\"\n    batch_size = batch_size or self.batch_size\n    points = points if points is not None else self.points\n    num_points = points.shape[0]\n    if randomize is None:\n      randomize = (\n          self.use_mini_batch and self.mini_batch_steps_per_iteration <= 1)\n\n    def _fn():\n      x = tf.constant(points)\n      if batch_size == num_points:\n        return tf.compat.v1.train.limit_epochs(x, num_epochs=num_epochs), None\n      if randomize:\n        indices = tf.random.uniform(\n            tf.constant([batch_size]),\n            minval=0,\n            maxval=num_points - 1,\n            dtype=tf.dtypes.int32,\n            seed=10)\n      else:\n        # We need to cycle through the indices sequentially. 
We create a queue\n        # to maintain the list of indices.\n        q = tf.queue.FIFOQueue(num_points, tf.dtypes.int32, ())\n\n        # Conditionally initialize the Queue.\n        def _init_q():\n          with tf.control_dependencies([q.enqueue_many(tf.range(num_points))]):\n            return tf.no_op()\n\n        init_q = tf.compat.v1.cond(q.size() <= 0, _init_q, tf.no_op)\n        with tf.control_dependencies([init_q]):\n          offsets = q.dequeue_many(batch_size)\n          with tf.control_dependencies([q.enqueue_many(offsets)]):\n            indices = tf.identity(offsets)\n      batch = tf.compat.v1.gather(x, indices)\n      return (tf.compat.v1.train.limit_epochs(batch,\n                                              num_epochs=num_epochs), None)\n\n    return _fn\n\n  @staticmethod\n  def config(tf_random_seed):\n    return run_config.RunConfig().replace(tf_random_seed=tf_random_seed)\n\n  @property\n  def initial_clusters(self):\n    return kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT\n\n  @property\n  def batch_size(self):\n    return self.num_points\n\n  @property\n  def use_mini_batch(self):\n    return False\n\n  @property\n  def mini_batch_steps_per_iteration(self):\n    return 1\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass KMeansTest(KMeansTestBase):\n\n  def setUp(self):\n    np.random.seed(3)\n    self.num_centers = 5\n    self.num_dims = 2\n    self.num_points = 1000\n    self.true_centers = make_random_centers(self.num_centers, self.num_dims)\n    self.points, _, self.scores = make_random_points(self.true_centers,\n                                                     self.num_points)\n    self.true_score = np.add.reduce(self.scores)\n\n  def _kmeans(self, relative_tolerance=None):\n    return kmeans_lib.KMeansClustering(\n        self.num_centers,\n        initial_clusters=self.initial_clusters,\n        distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE,\n        use_mini_batch=self.use_mini_batch,\n    
    mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration,\n        seed=24,\n        relative_tolerance=relative_tolerance)\n\n  def test_clusters(self):\n    kmeans = self._kmeans()\n    kmeans.train(input_fn=self.input_fn(), steps=1)\n    clusters = kmeans.cluster_centers()\n    self.assertAllEqual(list(clusters.shape), [self.num_centers, self.num_dims])\n\n  def test_fit(self):\n    kmeans = self._kmeans()\n    kmeans.train(input_fn=self.input_fn(), steps=1)\n    score1 = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points))\n    steps = 10 * self.num_points // self.batch_size\n    kmeans.train(input_fn=self.input_fn(), steps=steps)\n    score2 = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points))\n    self.assertTrue(score1 > score2)\n    self.assertNear(self.true_score, score2, self.true_score * 0.05)\n\n  def test_monitor(self):\n    if self.use_mini_batch:\n      # We don't test for use_mini_batch case since the loss value can be noisy.\n      return\n    kmeans = kmeans_lib.KMeansClustering(\n        self.num_centers,\n        initial_clusters=self.initial_clusters,\n        distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE,\n        use_mini_batch=self.use_mini_batch,\n        mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration,\n        config=self.config(14),\n        seed=12,\n        relative_tolerance=1e-4)\n\n    kmeans.train(\n        input_fn=self.input_fn(),\n        # Force it to train until the relative tolerance monitor stops it.\n        steps=None)\n    score = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points))\n    self.assertNear(self.true_score, score, self.true_score * 0.01)\n\n  def _infer_helper(self, kmeans, clusters, num_points):\n    points, true_assignments, true_offsets = make_random_points(\n        clusters, num_points)\n    input_fn = self.input_fn(batch_size=num_points, points=points, num_epochs=1)\n    # Test predict\n    assignments = 
list(kmeans.predict_cluster_index(input_fn))\n    self.assertAllEqual(assignments, true_assignments)\n\n    # Test score\n    score = kmeans.score(input_fn=lambda: (tf.constant(points), None))\n    self.assertNear(score, np.sum(true_offsets), 0.01 * score)\n\n    # Test transform\n    transform = list(kmeans.transform(input_fn))\n    true_transform = np.sqrt(\n        np.maximum(\n            0,\n            np.sum(np.square(points), axis=1, keepdims=True) -\n            2 * np.dot(points, np.transpose(clusters)) +\n            np.transpose(np.sum(np.square(clusters), axis=1, keepdims=True))))\n    self.assertAllClose(transform, true_transform, rtol=0.05, atol=10)\n\n  def test_infer(self):\n    kmeans = self._kmeans()\n    # Make a call to fit to initialize the cluster centers.\n    max_steps = 1\n    kmeans.train(input_fn=self.input_fn(), max_steps=max_steps)\n    clusters = kmeans.cluster_centers()\n\n    # Run inference on small datasets.\n    self._infer_helper(kmeans, clusters, 10)\n    self._infer_helper(kmeans, clusters, 1)\n\n  def _parse_feature_dict_helper(self, features, parsed_feature_dict):\n    # Perform a sanity check.\n    self.assertEqual(features.shape, parsed_feature_dict.shape)\n    self.assertEqual(features.dtype, parsed_feature_dict.dtype)\n    # Then check that running the tensor yields the original list of points.\n    with self.cached_session() as sess:\n      parsed_points = sess.run(parsed_feature_dict)\n      self.assertAllEqual(self.points, parsed_points)\n\n  def test_parse_features(self):\n    \"\"\"Tests the various behaviours of kmeans._parse_features_if_necessary.\"\"\"\n\n    # No-op if a tensor is passed in.\n    features = tf.constant(self.points)\n    parsed_features = kmeans_lib._parse_features_if_necessary(features, None)\n    self.assertAllEqual(features, parsed_features)\n\n    # All values from a feature dict are transformed into a tensor.\n    feature_dict = {\n        'x': [[point[0]] for point in self.points],\n        
'y': [[point[1]] for point in self.points]\n    }\n    parsed_feature_dict = kmeans_lib._parse_features_if_necessary(\n        feature_dict, None)\n    self._parse_feature_dict_helper(features, parsed_feature_dict)\n\n    # Only the feature_columns of a feature dict are transformed into a tensor.\n    feature_dict_with_extras = {\n        'foo': 'bar',\n        'x': [[point[0]] for point in self.points],\n        'baz': {'fizz': 'buzz'},\n        'y': [[point[1]] for point in self.points]\n    }\n    feature_columns = [\n        tf.feature_column.numeric_column(key='x'),\n        tf.feature_column.numeric_column(key='y')\n    ]\n    parsed_feature_dict = kmeans_lib._parse_features_if_necessary(\n        feature_dict_with_extras, feature_columns)\n    self._parse_feature_dict_helper(features, parsed_feature_dict)\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass KMeansTestMultiStageInit(KMeansTestBase):\n\n  def test_random(self):\n    points = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 0]],\n                      dtype=np.float32)\n    kmeans = kmeans_lib.KMeansClustering(\n        num_clusters=points.shape[0],\n        initial_clusters=kmeans_lib.KMeansClustering.RANDOM_INIT,\n        distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE,\n        use_mini_batch=True,\n        mini_batch_steps_per_iteration=100,\n        seed=24,\n        relative_tolerance=None)\n    kmeans.train(\n        input_fn=self.input_fn(batch_size=1, points=points, randomize=False),\n        steps=1)\n    clusters = kmeans.cluster_centers()\n    self.assertAllEqual(points, clusters)\n\n  def test_kmeans_plus_plus_batch_just_right(self):\n    points = np.array([[1, 2]], dtype=np.float32)\n    kmeans = kmeans_lib.KMeansClustering(\n        num_clusters=points.shape[0],\n        initial_clusters=kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT,\n        distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE,\n        use_mini_batch=True,\n        
mini_batch_steps_per_iteration=100,\n        seed=24,\n        relative_tolerance=None)\n    kmeans.train(\n        input_fn=self.input_fn(batch_size=1, points=points, randomize=False),\n        steps=1)\n    clusters = kmeans.cluster_centers()\n    self.assertAllEqual(points, clusters)\n\n  def test_kmeans_plus_plus_batch_too_small(self):\n    points = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 0]],\n                      dtype=np.float32)\n    kmeans = kmeans_lib.KMeansClustering(\n        num_clusters=points.shape[0],\n        initial_clusters=kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT,\n        distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE,\n        use_mini_batch=True,\n        mini_batch_steps_per_iteration=100,\n        seed=24,\n        relative_tolerance=None)\n    with self.assertRaisesOpError(AssertionError):\n      kmeans.train(\n          input_fn=self.input_fn(batch_size=4, points=points, randomize=False),\n          steps=1)\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass MiniBatchKMeansTest(KMeansTest):\n\n  @property\n  def batch_size(self):\n    return 50\n\n  @property\n  def use_mini_batch(self):\n    return True\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass FullBatchAsyncKMeansTest(KMeansTest):\n\n  @property\n  def batch_size(self):\n    return 50\n\n  @property\n  def use_mini_batch(self):\n    return True\n\n  @property\n  def mini_batch_steps_per_iteration(self):\n    return self.num_points // self.batch_size\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass KMeansCosineDistanceTest(KMeansTestBase):\n\n  def setUp(self):\n    self.points = np.array([[2.5, 0.1], [2, 0.2], [3, 0.1], [4, 0.2],\n                            [0.1, 2.5], [0.2, 2], [0.1, 3], [0.2, 4]],\n                           dtype=np.float32)\n    self.num_points = self.points.shape[0]\n    self.true_centers = np.array([\n        normalize(\n            np.mean(normalize(self.points)[0:4, :], axis=0, 
keepdims=True))[0],\n        normalize(\n            np.mean(normalize(self.points)[4:, :], axis=0, keepdims=True))[0]\n    ],\n                                 dtype=np.float32)\n    self.true_assignments = np.array([0] * 4 + [1] * 4)\n    self.true_score = len(self.points) - np.tensordot(\n        normalize(self.points), self.true_centers[self.true_assignments])\n\n    self.num_centers = 2\n    self.kmeans = kmeans_lib.KMeansClustering(\n        self.num_centers,\n        initial_clusters=kmeans_lib.KMeansClustering.RANDOM_INIT,\n        distance_metric=kmeans_lib.KMeansClustering.COSINE_DISTANCE,\n        use_mini_batch=self.use_mini_batch,\n        mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration,\n        config=self.config(3))\n\n  def test_fit(self):\n    max_steps = 10 * self.num_points // self.batch_size\n    self.kmeans.train(input_fn=self.input_fn(), max_steps=max_steps)\n    centers = normalize(self.kmeans.cluster_centers())\n    centers = centers[centers[:, 0].argsort()]\n    true_centers = self.true_centers[self.true_centers[:, 0].argsort()]\n    self.assertAllClose(centers, true_centers, atol=0.04)\n\n  def test_transform(self):\n    self.kmeans.train(input_fn=self.input_fn(), steps=10)\n    centers = normalize(self.kmeans.cluster_centers())\n    true_transform = 1 - cosine_similarity(self.points, centers)\n    transform = list(\n        self.kmeans.transform(\n            input_fn=self.input_fn(batch_size=self.num_points, num_epochs=1)))\n    self.assertAllClose(transform, true_transform, atol=1e-3)\n\n  def test_predict(self):\n    max_steps = 10 * self.num_points // self.batch_size\n    self.kmeans.train(input_fn=self.input_fn(), max_steps=max_steps)\n    centers = normalize(self.kmeans.cluster_centers())\n\n    assignments = list(\n        self.kmeans.predict_cluster_index(\n            input_fn=self.input_fn(num_epochs=1, batch_size=self.num_points)))\n    self.assertAllClose(\n        centers[assignments],\n        
self.true_centers[self.true_assignments],\n        atol=1e-2)\n\n    centers = centers[centers[:, 0].argsort()]\n    true_centers = self.true_centers[self.true_centers[:, 0].argsort()]\n    self.assertAllClose(centers, true_centers, atol=0.04)\n    score = self.kmeans.score(\n        input_fn=self.input_fn(batch_size=self.num_points))\n    self.assertAllClose(score, self.true_score, atol=1e-2)\n\n  def test_predict_kmeans_plus_plus(self):\n    # Most points are concentrated near one center. KMeans++ is likely to find\n    # the less populated centers.\n    points = np.array([[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3],\n                       [-3.1, -3.2], [-2.8, -3.], [-2.9, -3.1], [-3., -3.1],\n                       [-3., -3.1], [-3.2, -3.], [-3., -3.]],\n                      dtype=np.float32)\n    true_centers = np.array([\n        normalize(np.mean(normalize(points)[0:2, :], axis=0, keepdims=True))[0],\n        normalize(np.mean(normalize(points)[2:4, :], axis=0, keepdims=True))[0],\n        normalize(np.mean(normalize(points)[4:, :], axis=0, keepdims=True))[0]\n    ],\n                            dtype=np.float32)\n    true_assignments = [0] * 2 + [1] * 2 + [2] * 8\n    true_score = len(points) - np.tensordot(\n        normalize(points), true_centers[true_assignments])\n    kmeans = kmeans_lib.KMeansClustering(\n        3,\n        initial_clusters=self.initial_clusters,\n        distance_metric=kmeans_lib.KMeansClustering.COSINE_DISTANCE,\n        use_mini_batch=self.use_mini_batch,\n        mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration,\n        config=self.config(3))\n    kmeans.train(input_fn=lambda: (tf.constant(points), None), steps=30)\n\n    centers = normalize(kmeans.cluster_centers())\n    self.assertAllClose(\n        sorted(centers.tolist()), sorted(true_centers.tolist()), atol=1e-2)\n\n    def _input_fn():\n      return (tf.compat.v1.train.limit_epochs(\n          tf.constant(points), num_epochs=1), None)\n\n    
assignments = list(kmeans.predict_cluster_index(input_fn=_input_fn))\n    self.assertAllClose(\n        centers[assignments], true_centers[true_assignments], atol=1e-2)\n\n    score = kmeans.score(input_fn=lambda: (tf.constant(points), None))\n    self.assertAllClose(score, true_score, atol=1e-2)\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass MiniBatchKMeansCosineTest(KMeansCosineDistanceTest):\n\n  @property\n  def batch_size(self):\n    return 2\n\n  @property\n  def use_mini_batch(self):\n    return True\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass FullBatchAsyncKMeansCosineTest(KMeansCosineDistanceTest):\n\n  @property\n  def batch_size(self):\n    return 2\n\n  @property\n  def use_mini_batch(self):\n    return True\n\n  @property\n  def mini_batch_steps_per_iteration(self):\n    return self.num_points // self.batch_size\n\n\nclass KMeansBenchmark(benchmark.Benchmark):\n  \"\"\"Base class for benchmarks.\"\"\"\n\n  def SetUp(self,\n            dimension=50,\n            num_clusters=50,\n            points_per_cluster=10000,\n            center_norm=500,\n            cluster_width=20):\n    np.random.seed(123456)\n    self.num_clusters = num_clusters\n    self.num_points = num_clusters * points_per_cluster\n    self.centers = make_random_centers(\n        self.num_clusters, dimension, center_norm=center_norm)\n    self.points, _, scores = make_random_points(\n        self.centers, self.num_points, max_offset=cluster_width)\n    self.score = float(np.sum(scores))\n\n  def _report(self, num_iters, start, end, scores):\n    print(scores)\n    self.report_benchmark(\n        iters=num_iters,\n        wall_time=(end - start) / num_iters,\n        extras={\n            'true_sum_squared_distances': self.score,\n            'fit_scores': scores\n        })\n\n  def _fit(self, num_iters=10):\n    pass\n\n  def benchmark_01_2dim_5center_500point(self):\n    self.SetUp(dimension=2, num_clusters=5, points_per_cluster=100)\n    self._fit()\n\n  def 
benchmark_02_20dim_20center_10kpoint(self):\n    self.SetUp(dimension=20, num_clusters=20, points_per_cluster=500)\n    self._fit()\n\n  def benchmark_03_100dim_50center_50kpoint(self):\n    self.SetUp(dimension=100, num_clusters=50, points_per_cluster=1000)\n    self._fit()\n\n  def benchmark_03_100dim_50center_50kpoint_unseparated(self):\n    self.SetUp(\n        dimension=100,\n        num_clusters=50,\n        points_per_cluster=1000,\n        cluster_width=250)\n    self._fit()\n\n  def benchmark_04_100dim_500center_500kpoint(self):\n    self.SetUp(dimension=100, num_clusters=500, points_per_cluster=1000)\n    self._fit(num_iters=4)\n\n  def benchmark_05_100dim_500center_500kpoint_unseparated(self):\n    self.SetUp(\n        dimension=100,\n        num_clusters=500,\n        points_per_cluster=1000,\n        cluster_width=250)\n    self._fit(num_iters=4)\n\n\nclass TensorflowKMeansBenchmark(KMeansBenchmark):\n\n  def _fit(self, num_iters=10):\n    scores = []\n    start = time.time()\n    for i in range(num_iters):\n      print('Starting tensorflow KMeans: %d' % i)\n      tf_kmeans = kmeans_lib.KMeansClustering(\n          self.num_clusters,\n          initial_clusters=kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT,\n          kmeans_plus_plus_num_retries=int(math.log(self.num_clusters) + 2),\n          seed=i * 42,\n          relative_tolerance=1e-6,\n          config=self.config(3))\n      tf_kmeans.train(\n          input_fn=lambda: (tf.constant(self.points), None), steps=50)\n      _ = tf_kmeans.cluster_centers()\n      scores.append(\n          tf_kmeans.score(input_fn=lambda: (tf.constant(self.points), None)))\n    self._report(num_iters, start, time.time(), scores)\n\n\nclass SklearnKMeansBenchmark(KMeansBenchmark):\n\n  def _fit(self, num_iters=10):\n    scores = []\n    start = time.time()\n    for i in range(num_iters):\n      print('Starting sklearn KMeans: %d' % i)\n      sklearn_kmeans = SklearnKMeans(\n          n_clusters=self.num_clusters,\n 
         init='k-means++',\n          max_iter=50,\n          n_init=1,\n          tol=1e-4,\n          random_state=i * 42)\n      sklearn_kmeans.train(self.points)\n      scores.append(sklearn_kmeans.inertia_)\n    self._report(num_iters, start, time.time(), scores)\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass KMeansTestQueues(tf.test.TestCase):\n\n  def input_fn(self):\n\n    def _fn():\n      queue = tf.queue.FIFOQueue(\n          capacity=10, dtypes=tf.dtypes.float32, shapes=[10, 3])\n      enqueue_op = queue.enqueue(tf.zeros([10, 3], dtype=tf.dtypes.float32))\n      tf.compat.v1.train.queue_runner.add_queue_runner(\n          tf.compat.v1.train.queue_runner.QueueRunner(queue, [enqueue_op]))\n      return queue.dequeue(), None\n\n    return _fn\n\n  # This test makes sure that there are no deadlocks when using a QueueRunner.\n  # Note that since cluster initialization is dependent on inputs, if input\n  # is generated using a QueueRunner, one has to make sure that these runners\n  # are started before the initialization.\n  def test_queues(self):\n    kmeans = kmeans_lib.KMeansClustering(5)\n    kmeans.train(input_fn=self.input_fn(), steps=1)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/linear.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Linear Estimators.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.feature_column import feature_column\nfrom tensorflow.python.feature_column import feature_column_lib\nfrom tensorflow.python.feature_column import feature_column_v2 as fc_v2\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.ops import variable_scope\nfrom tensorflow_estimator.python.estimator import estimator\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import head as head_lib\nfrom tensorflow_estimator.python.estimator.canned import optimizers\nfrom tensorflow_estimator.python.estimator.canned.linear_optimizer.python.utils import sdca_ops\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.head import binary_class_head\nfrom tensorflow_estimator.python.estimator.head import head_utils\nfrom tensorflow_estimator.python.estimator.head import regression_head\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n# The default learning rate of 0.2 is a 
historical artifact of the initial\n# implementation, but seems a reasonable choice.\n_LEARNING_RATE = 0.2\n\n\n@estimator_export('estimator.experimental.LinearSDCA')\nclass LinearSDCA(object):\n  \"\"\"Stochastic Dual Coordinate Ascent helper for linear estimators.\n\n  Objects of this class are intended to be provided as the optimizer argument\n  (though LinearSDCA objects do not implement the `tf.train.Optimizer`\n  interface)\n  when creating `tf.estimator.LinearClassifier` or\n  `tf.estimator.LinearRegressor`.\n\n  SDCA can only be used with `LinearClassifier` and `LinearRegressor` under the\n  following conditions:\n\n    - Feature columns are of type V2.\n    - Multivalent categorical columns are not normalized. In other words the\n      `sparse_combiner` argument in the estimator constructor should be \"sum\".\n    - For classification: binary label.\n    - For regression: one-dimensional label.\n\n  Example usage:\n\n  ```python\n  real_feature_column = numeric_column(...)\n  sparse_feature_column = categorical_column_with_hash_bucket(...)\n  linear_sdca = tf.estimator.experimental.LinearSDCA(\n      example_id_column='example_id',\n      num_loss_partitions=1,\n      num_table_shards=1,\n      symmetric_l2_regularization=2.0)\n  classifier = tf.estimator.LinearClassifier(\n      feature_columns=[real_feature_column, sparse_feature_column],\n      weight_column=...,\n      optimizer=linear_sdca)\n  classifier.train(input_fn_train, steps=50)\n  classifier.evaluate(input_fn=input_fn_eval)\n  ```\n\n  Here the expectation is that the `input_fn_*` functions passed to train and\n  evaluate return a pair (dict, label_tensor) where dict has `example_id_column`\n  as `key` whose value is a `Tensor` of shape [batch_size] and dtype string.\n  num_loss_partitions defines sigma' in eq (11) of [3]. Convergence of (global)\n  loss is guaranteed if `num_loss_partitions` is larger or equal to the product\n  `(#concurrent train ops/per worker) x (#workers)`. 
Larger values for\n  `num_loss_partitions` lead to slower convergence. The recommended value for\n  `num_loss_partitions` in `tf.estimator` (where currently there is one process\n  per worker) is the number of workers running the train steps. It defaults to 1\n  (single machine).\n  `num_table_shards` defines the number of shards for the internal state\n  table, typically set to match the number of parameter servers for large\n  data sets.\n\n  The SDCA algorithm was originally introduced in [1] and it was followed by\n  the L1 proximal step [2], a distributed version [3] and adaptive sampling [4].\n  [1] www.jmlr.org/papers/volume14/shalev-shwartz13a/shalev-shwartz13a.pdf\n  [2] https://arxiv.org/pdf/1309.2375.pdf\n  [3] https://arxiv.org/pdf/1502.03508.pdf\n  [4] https://arxiv.org/pdf/1502.08053.pdf\n  Details specific to this implementation are provided in:\n  https://github.com/tensorflow/estimator/tree/master/tensorflow_estimator/python/estimator/canned/linear_optimizer/doc/sdca.ipynb\n  \"\"\"\n\n  def __init__(self,\n               example_id_column,\n               num_loss_partitions=1,\n               num_table_shards=None,\n               symmetric_l1_regularization=0.0,\n               symmetric_l2_regularization=1.0,\n               adaptive=False):\n    \"\"\"Construct a new SDCA optimizer for linear estimators.\n\n    Args:\n      example_id_column: The column name containing the example ids.\n      num_loss_partitions: Number of workers.\n      num_table_shards: Number of shards of the internal state table, typically\n        set to match the number of parameter servers.\n      symmetric_l1_regularization: A float value, must be greater than or equal\n        to zero.\n      symmetric_l2_regularization: A float value, must be greater than zero and\n        should typically be greater than 1.\n      adaptive: A boolean indicating whether to use adaptive sampling.\n    \"\"\"\n\n    self._example_id_column = example_id_column\n    
self._num_loss_partitions = num_loss_partitions\n    self._num_table_shards = num_table_shards\n    self._symmetric_l1_regularization = symmetric_l1_regularization\n    self._symmetric_l2_regularization = symmetric_l2_regularization\n    self._adaptive = adaptive\n\n  def _prune_and_unique_sparse_ids(self, id_weight_pair):\n    \"\"\"Remove duplicate and negative ids in a sparse tendor.\"\"\"\n\n    id_tensor = id_weight_pair.id_tensor\n    if id_weight_pair.weight_tensor:\n      weight_tensor = id_weight_pair.weight_tensor.values\n    else:\n      weight_tensor = tf.ones([tf.compat.v1.shape(id_tensor.indices)[0]],\n                              tf.dtypes.float32)\n\n    example_ids = tf.reshape(id_tensor.indices[:, 0], [-1])\n    flat_ids = tf.cast(\n        tf.reshape(id_tensor.values, [-1]), dtype=tf.dtypes.int64)\n    # Prune invalid IDs (< 0) from the flat_ids, example_ids, and\n    # weight_tensor.  These can come from looking up an OOV entry in the\n    # vocabulary (default value being -1).\n    is_id_valid = tf.math.greater_equal(flat_ids, 0)\n    flat_ids = tf.compat.v1.boolean_mask(flat_ids, is_id_valid)\n    example_ids = tf.compat.v1.boolean_mask(example_ids, is_id_valid)\n    weight_tensor = tf.compat.v1.boolean_mask(weight_tensor, is_id_valid)\n\n    projection_length = tf.math.reduce_max(flat_ids) + 1\n    # project ids based on example ids so that we can dedup ids that\n    # occur multiple times for a single example.\n    projected_ids = projection_length * example_ids + flat_ids\n\n    # Remove any redundant ids.\n    ids, idx = tf.unique(projected_ids)\n    # Keep only one example id per duplicated ids.\n    example_ids_filtered = tf.math.unsorted_segment_min(\n        example_ids, idx,\n        tf.compat.v1.shape(ids)[0])\n\n    # reproject ids back feature id space.\n    reproject_ids = (ids - projection_length * example_ids_filtered)\n\n    weights = tf.reshape(\n        tf.math.unsorted_segment_sum(weight_tensor, idx,\n                       
              tf.compat.v1.shape(ids)[0]), [-1])\n    return sdca_ops._SparseFeatureColumn(  # pylint: disable=protected-access\n        example_ids_filtered, reproject_ids, weights)\n\n  def get_train_step(self, state_manager, weight_column_name, loss_type,\n                     feature_columns, features, targets, bias_var, global_step):\n    \"\"\"Returns the training operation of an SdcaModel optimizer.\"\"\"\n\n    batch_size = tf.compat.v1.shape(targets)[0]\n    cache = tf.compat.v2.__internal__.feature_column.FeatureTransformationCache(features)\n\n    # Iterate over all feature columns and create appropriate lists for dense\n    # and sparse features as well as dense and sparse weights (variables) for\n    # SDCA.\n    dense_features, dense_feature_weights = [], []\n    sparse_feature_with_values, sparse_feature_with_values_weights = [], []\n    for column in sorted(feature_columns, key=lambda x: x.name):\n      if isinstance(column, feature_column_lib.CategoricalColumn):\n        id_weight_pair = column.get_sparse_tensors(cache, state_manager)\n        sparse_feature_with_values.append(\n            self._prune_and_unique_sparse_ids(id_weight_pair))\n        # If a partitioner was used during variable creation, we will have a\n        # list of Variables here larger than 1.\n        sparse_feature_with_values_weights.append(\n            state_manager.get_variable(column, 'weights'))\n      elif isinstance(column, tf.compat.v2.__internal__.feature_column.DenseColumn):\n        if column.variable_shape.ndims != 1:\n          raise ValueError('Column %s has rank %d, larger than 1.' 
%\n                           (type(column).__name__, column.variable_shape.ndims))\n        dense_features.append(column.get_dense_tensor(cache, state_manager))\n        # For real valued columns, the variables list contains exactly one\n        # element.\n        dense_feature_weights.append(\n            state_manager.get_variable(column, 'weights'))\n      else:\n        raise ValueError('LinearSDCA does not support column type %s.' %\n                         type(column).__name__)\n\n    # Add the bias column\n    dense_features.append(tf.ones([batch_size, 1]))\n    dense_feature_weights.append(bias_var)\n\n    example_weights = tf.reshape(\n        features[weight_column_name],\n        shape=[-1]) if weight_column_name else tf.ones([batch_size])\n    example_ids = features[self._example_id_column]\n    training_examples = dict(\n        sparse_features=sparse_feature_with_values,\n        dense_features=dense_features,\n        example_labels=tf.compat.v1.to_float(tf.reshape(targets, shape=[-1])),\n        example_weights=example_weights,\n        example_ids=example_ids)\n    training_variables = dict(\n        sparse_features_weights=sparse_feature_with_values_weights,\n        dense_features_weights=dense_feature_weights)\n    sdca_model = sdca_ops._SDCAModel(  # pylint: disable=protected-access\n        examples=training_examples,\n        variables=training_variables,\n        options=dict(\n            symmetric_l1_regularization=self._symmetric_l1_regularization,\n            symmetric_l2_regularization=self._symmetric_l2_regularization,\n            adaptive=self._adaptive,\n            num_loss_partitions=self._num_loss_partitions,\n            num_table_shards=self._num_table_shards,\n            loss_type=loss_type))\n    train_op = sdca_model.minimize(global_step=global_step)\n    return sdca_model, train_op\n\n\ndef _get_default_optimizer_v2(feature_columns):\n  learning_rate = min(_LEARNING_RATE, 1.0 / math.sqrt(len(feature_columns)))\n  
return tf_keras.optimizers.legacy.Ftrl(learning_rate=learning_rate)\n\n\ndef _get_default_optimizer(feature_columns):\n  learning_rate = min(_LEARNING_RATE, 1.0 / math.sqrt(len(feature_columns)))\n  return tf.compat.v1.train.FtrlOptimizer(learning_rate=learning_rate)\n\n\ndef _get_expanded_variable_list(var_list):\n  \"\"\"Given an iterable of variables, expands them if they are partitioned.\n\n  Args:\n    var_list: An iterable of variables.\n\n  Returns:\n    A list of variables where each partitioned variable is expanded to its\n    components.\n  \"\"\"\n  returned_list = []\n  for variable in var_list:\n    if (isinstance(variable, tf.Variable) or\n        tf.compat.v2.__internal__.ops.is_resource_variable(variable) or\n        isinstance(variable, tf.Tensor)):\n      returned_list.append(variable)  # Single variable/tensor case.\n    else:  # Must be a PartitionedVariable, so convert into a list.\n      returned_list.extend(list(variable))\n  return returned_list\n\n\n# TODO(rohanj): Consider making this a public utility method.\ndef _compute_fraction_of_zero(variables):\n  \"\"\"Given a linear variables list, compute the fraction of zero weights.\n\n  Args:\n    variables: A list or list of list of variables\n\n  Returns:\n    The fraction of zeros (sparsity) in the linear model.\n  \"\"\"\n  with ops.name_scope('zero_fraction'):\n    variables = tf.nest.flatten(variables)\n\n    with ops.name_scope('total_size'):\n      sizes = [\n          tf.compat.v1.size(x, out_type=tf.dtypes.int64) for x in variables\n      ]\n      total_size_int64 = tf.math.add_n(sizes)\n    with ops.name_scope('total_zero'):\n      total_zero_float32 = tf.math.add_n([\n          tf.compat.v1.cond(\n              tf.math.equal(size, tf.constant(0, dtype=tf.dtypes.int64)),\n              true_fn=lambda: tf.constant(0, dtype=tf.dtypes.float32),\n              false_fn=lambda: tf.math.zero_fraction(x) * tf.cast(\n                  size, dtype=tf.dtypes.float32),\n              
name='zero_count') for x, size in zip(variables, sizes)\n      ])\n\n    with ops.name_scope('compute'):\n      total_size_float32 = tf.cast(\n          total_size_int64, dtype=tf.dtypes.float32, name='float32_size')\n      zero_fraction_or_nan = total_zero_float32 / total_size_float32\n\n    zero_fraction_or_nan = tf.identity(\n        zero_fraction_or_nan, name='zero_fraction_or_nan')\n    return zero_fraction_or_nan\n\n\ndef linear_logit_fn_builder_v2(units, feature_columns, sparse_combiner='sum'):\n  \"\"\"Function builder for a linear logit_fn.\n\n  Args:\n    units: An int indicating the dimension of the logit layer.\n    feature_columns: An iterable containing all the feature columns used by the\n      model.\n    sparse_combiner: A string specifying how to reduce if a categorical column\n      is multivalent.  One of \"mean\", \"sqrtn\", and \"sum\".\n\n  Returns:\n    A logit_fn (see below).\n\n  \"\"\"\n\n  def linear_logit_fn(features):\n    \"\"\"Linear model logit_fn.\n\n    Args:\n      features: This is the first item returned from the `input_fn` passed to\n        `train`, `evaluate`, and `predict`. This should be a single `Tensor` or\n        `dict` of same.\n\n    Returns:\n      A `Tensor` representing the logits.\n    \"\"\"\n    if not feature_column_lib.is_feature_column_v2(feature_columns):\n      raise ValueError(\n          'Received a feature column from TensorFlow v1, but this is a '\n          'TensorFlow v2 Estimator. 
Please either use v2 feature columns '\n          '(accessible via tf.feature_column.* in TF 2.x) with this '\n          'Estimator, or switch to a v1 Estimator for use with v1 feature '\n          'columns (accessible via tf.compat.v1.estimator.* and '\n          'tf.compat.v1.feature_column.*, respectively.')\n\n    linear_model = LinearModel(\n        feature_columns=feature_columns,\n        units=units,\n        sparse_combiner=sparse_combiner,\n        name='linear_model')\n    logits = linear_model(features)\n    bias = linear_model.bias\n\n    # We'd like to get all the non-bias variables associated with this\n    # LinearModel.\n    # TODO(rohanj): Figure out how to get shared embedding weights variable\n    # here.\n    variables = linear_model.variables\n    variables.remove(bias)\n\n    # Expand (potential) Partitioned variables\n    bias = _get_expanded_variable_list([bias])\n    variables = _get_expanded_variable_list(variables)\n\n    if units > 1:\n      tf.compat.v1.summary.histogram('bias', bias)\n    else:\n      # If units == 1, the bias value is a length-1 list of a scalar Tensor,\n      # so we should provide a scalar summary.\n      tf.compat.v1.summary.scalar('bias', bias[0][0])\n    tf.compat.v1.summary.scalar('fraction_of_zero_weights',\n                                _compute_fraction_of_zero(variables))\n    return logits\n\n  return linear_logit_fn\n\n\n@estimator_export(v1=['estimator.experimental.linear_logit_fn_builder'])\ndef linear_logit_fn_builder(units, feature_columns, sparse_combiner='sum'):\n  \"\"\"Function builder for a linear logit_fn.\n\n  Args:\n    units: An int indicating the dimension of the logit layer.\n    feature_columns: An iterable containing all the feature columns used by the\n      model.\n    sparse_combiner: A string specifying how to reduce if a categorical column\n      is multivalent.  
One of \"mean\", \"sqrtn\", and \"sum\".\n\n  Returns:\n    A logit_fn (see below).\n\n  \"\"\"\n\n  def linear_logit_fn(features):\n    \"\"\"Linear model logit_fn.\n\n    Args:\n      features: This is the first item returned from the `input_fn` passed to\n        `train`, `evaluate`, and `predict`. This should be a single `Tensor` or\n        `dict` of same.\n\n    Returns:\n      A `Tensor` representing the logits.\n    \"\"\"\n    if feature_column_lib.is_feature_column_v2(feature_columns):\n      linear_model = LinearModel(\n          feature_columns=feature_columns,\n          units=units,\n          sparse_combiner=sparse_combiner,\n          name='linear_model')\n      logits = linear_model(features)\n\n      # We'd like to get all the non-bias variables associated with this\n      # LinearModel.\n      # TODO(rohanj): Figure out how to get shared embedding weights variable\n      # here.\n      bias = linear_model.bias\n      variables = linear_model.variables\n      # Expand (potential) Partitioned variables\n      bias = _get_expanded_variable_list([bias])\n      variables = _get_expanded_variable_list(variables)\n      variables = [var for var in variables if var not in bias]\n\n      # Expand (potential) Partitioned variables\n      bias = _get_expanded_variable_list([bias])\n    else:\n      linear_model = feature_column._LinearModel(  # pylint: disable=protected-access\n          feature_columns=feature_columns,\n          units=units,\n          sparse_combiner=sparse_combiner,\n          name='linear_model')\n      logits = linear_model(features)\n      cols_to_vars = linear_model.cols_to_vars()\n      bias = cols_to_vars.pop('bias')\n      variables = cols_to_vars.values()\n      variables = _get_expanded_variable_list(variables)\n\n    if units > 1:\n      tf.compat.v1.summary.histogram('bias', bias)\n    else:\n      # If units == 1, the bias value is a length-1 list of a scalar Tensor,\n      # so we should provide a scalar summary.\n      
tf.compat.v1.summary.scalar('bias', bias[0][0])\n    tf.compat.v1.summary.scalar('fraction_of_zero_weights',\n                                _compute_fraction_of_zero(variables))\n    return logits\n\n  return linear_logit_fn\n\n\ndef _sdca_model_fn(features, labels, mode, head, feature_columns, optimizer):\n  \"\"\"A model_fn for linear models that use the SDCA optimizer.\n\n  Args:\n    features: dict of `Tensor`.\n    labels: `Tensor` of shape `[batch_size]`.\n    mode: Defines whether this is training, evaluation or prediction. See\n      `ModeKeys`.\n    head: A `Head` instance.\n    feature_columns: An iterable containing all the feature columns used by the\n      model.\n    optimizer: a `LinearSDCA` instance.\n\n  Returns:\n    An `EstimatorSpec` instance.\n\n  Raises:\n    ValueError: mode or params are invalid, or features has the wrong type.\n  \"\"\"\n  assert feature_column_lib.is_feature_column_v2(feature_columns)\n  if isinstance(head,\n                (binary_class_head.BinaryClassHead,\n                 head_lib._BinaryLogisticHeadWithSigmoidCrossEntropyLoss)):  # pylint: disable=protected-access\n    loss_type = 'logistic_loss'\n  elif isinstance(head, (regression_head.RegressionHead,\n                         head_lib._RegressionHeadWithMeanSquaredErrorLoss)):  # pylint: disable=protected-access\n    assert head.logits_dimension == 1\n    loss_type = 'squared_loss'\n  else:\n    raise ValueError('Unsupported head type: {}'.format(head))\n\n  # The default name for LinearModel.\n  linear_model_name = 'linear_model'\n\n  # Name scope has no effect on variables in LinearModel, as it uses\n  # tf.get_variables() for variable creation. 
So we modify the model name to\n  # keep the variable names the same for checkpoint backward compatibility in\n  # canned Linear v2.\n  if isinstance(\n      head,\n      (binary_class_head.BinaryClassHead, regression_head.RegressionHead)):\n    linear_model_name = 'linear/linear_model'\n\n  linear_model = LinearModel(\n      feature_columns=feature_columns,\n      units=1,\n      sparse_combiner='sum',\n      name=linear_model_name)\n  logits = linear_model(features)\n\n  # We'd like to get all the non-bias variables associated with this\n  # LinearModel.\n  # TODO(rohanj): Figure out how to get shared embedding weights variable\n  # here.\n  bias = linear_model.bias\n  variables = linear_model.variables\n  # Expand (potential) Partitioned variables\n  bias = _get_expanded_variable_list([bias])\n  variables = _get_expanded_variable_list(variables)\n  variables = [var for var in variables if var not in bias]\n\n  tf.compat.v1.summary.scalar('bias', bias[0][0])\n  tf.compat.v1.summary.scalar('fraction_of_zero_weights',\n                              _compute_fraction_of_zero(variables))\n\n  if mode == ModeKeys.TRAIN:\n    sdca_model, train_op = optimizer.get_train_step(\n        linear_model.layer._state_manager,  # pylint: disable=protected-access\n        head._weight_column,  # pylint: disable=protected-access\n        loss_type,\n        feature_columns,\n        features,\n        labels,\n        linear_model.bias,\n        tf.compat.v1.train.get_global_step())\n\n    update_weights_hook = _SDCAUpdateWeightsHook(sdca_model, train_op)\n\n    model_fn_ops = head.create_estimator_spec(\n        features=features,\n        mode=mode,\n        labels=labels,\n        train_op_fn=lambda unused_loss_fn: train_op,\n        logits=logits)\n    return model_fn_ops._replace(\n        training_chief_hooks=(model_fn_ops.training_chief_hooks +\n                              (update_weights_hook,)))\n  else:\n    return head.create_estimator_spec(\n        
features=features, mode=mode, labels=labels, logits=logits)\n\n\nclass _SDCAUpdateWeightsHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"SessionRunHook to update and shrink SDCA model weights.\"\"\"\n\n  def __init__(self, sdca_model, train_op):\n    self._sdca_model = sdca_model\n    self._train_op = train_op\n\n  def begin(self):\n    \"\"\"Construct the update_weights op.\n\n    The op is implicitly added to the default graph.\n    \"\"\"\n    self._update_op = self._sdca_model.update_weights(self._train_op)\n\n  def before_run(self, run_context):\n    \"\"\"Return the update_weights op so that it is executed during this run.\"\"\"\n    return tf.compat.v1.train.SessionRunArgs(self._update_op)\n\n\ndef _linear_model_fn_builder_v2(units,\n                                feature_columns,\n                                sparse_combiner='sum',\n                                features=None):\n  \"\"\"Function builder for a linear model_fn.\n\n  Args:\n    units: An int indicating the dimension of the logit layer.\n    feature_columns: An iterable containing all the feature columns used by the\n      model.\n    sparse_combiner: A string specifying how to reduce if a categorical column\n      is multivalent.  One of \"mean\", \"sqrtn\", and \"sum\".\n    features: This is the first item returned from the `input_fn` passed to\n      `train`, `evaluate`, and `predict`. This should be a single `Tensor` or\n      `dict` of same.\n\n  Returns:\n    A `Tensor` representing the logits.\n    A list of trainable variables.\n\n  \"\"\"\n  if not feature_column_lib.is_feature_column_v2(feature_columns):\n    raise ValueError(\n        'Received a feature column from TensorFlow v1, but this is a '\n        'TensorFlow v2 Estimator. 
Please either use v2 feature columns '\n        '(accessible via tf.feature_column.* in TF 2.x) with this '\n        'Estimator, or switch to a v1 Estimator for use with v1 feature '\n        'columns (accessible via tf.compat.v1.estimator.* and '\n        'tf.compat.v1.feature_column.*, respectively.')\n\n  # Name scope has no effect on variables in LinearModel, as it uses\n  # tf.get_variables() for variable creation. So we modify the model name to\n  # keep the variable names the same for checkpoint backward compatibility.\n  linear_model = LinearModel(\n      feature_columns=feature_columns,\n      units=units,\n      sparse_combiner=sparse_combiner,\n      name='linear/linear_model')\n  logits = linear_model(features)\n  bias = linear_model.bias\n\n  # We'd like to get all the non-bias variables associated with this\n  # LinearModel.\n  # TODO(rohanj): Figure out how to get shared embedding weights variable\n  # here.\n  variables = linear_model.variables\n  variables.remove(bias)\n\n  if units > 1:\n    tf.compat.v1.summary.histogram('bias', bias)\n  else:\n    # If units == 1, the bias value is a length-1 list of a scalar Tensor,\n    # so we should provide a scalar summary.\n    tf.compat.v1.summary.scalar('bias', bias[0])\n  tf.compat.v1.summary.scalar('fraction_of_zero_weights',\n                              _compute_fraction_of_zero(variables))\n\n  return logits, linear_model.variables\n\n\ndef _linear_model_fn_v2(features,\n                        labels,\n                        mode,\n                        head,\n                        feature_columns,\n                        optimizer,\n                        config,\n                        sparse_combiner='sum'):\n  \"\"\"A model_fn for linear models that use a gradient-based optimizer.\n\n  Args:\n    features: dict of `Tensor`.\n    labels: `Tensor` of shape `[batch_size, logits_dimension]`.\n    mode: Defines whether this is training, evaluation or prediction. 
See\n      `ModeKeys`.\n    head: A `Head` instance.\n    feature_columns: An iterable containing all the feature columns used by the\n      model.\n    optimizer: string, `Optimizer` object, or callable that defines the\n      optimizer to use for training. If `None`, will use a FTRL optimizer.\n    config: `RunConfig` object to configure the runtime settings.\n    sparse_combiner: A string specifying how to reduce if a categorical column\n      is multivalent.  One of \"mean\", \"sqrtn\", and \"sum\".\n\n  Returns:\n    An `EstimatorSpec` instance.\n\n  Raises:\n    ValueError: mode or params are invalid, or features has the wrong type.\n  \"\"\"\n  if not isinstance(features, dict):\n    raise ValueError('features should be a dictionary of `Tensor`s. '\n                     'Given type: {}'.format(type(features)))\n\n  del config\n\n  if isinstance(optimizer, LinearSDCA):\n    assert sparse_combiner == 'sum'\n    return _sdca_model_fn(features, labels, mode, head, feature_columns,\n                          optimizer)\n  else:\n    logits, trainable_variables = _linear_model_fn_builder_v2(\n        units=head.logits_dimension,\n        feature_columns=feature_columns,\n        sparse_combiner=sparse_combiner,\n        features=features)\n\n    # In TRAIN mode, create optimizer and assign global_step variable to\n    # optimizer.iterations to make global_step increased correctly, as Hooks\n    # relies on global step as step counter.\n    if mode == ModeKeys.TRAIN:\n      optimizer = optimizers.get_optimizer_instance_v2(\n          optimizer or _get_default_optimizer_v2(feature_columns),\n          learning_rate=_LEARNING_RATE)\n      optimizer.iterations = tf.compat.v1.train.get_or_create_global_step()\n\n    return head.create_estimator_spec(\n        features=features,\n        mode=mode,\n        labels=labels,\n        optimizer=optimizer,\n        trainable_variables=trainable_variables,\n        logits=logits)\n\n\ndef _linear_model_fn(features,\n          
           labels,\n                     mode,\n                     head,\n                     feature_columns,\n                     optimizer,\n                     partitioner,\n                     config,\n                     sparse_combiner='sum'):\n  \"\"\"A model_fn for linear models that use a gradient-based optimizer.\n\n  Args:\n    features: dict of `Tensor`.\n    labels: `Tensor` of shape `[batch_size, logits_dimension]`.\n    mode: Defines whether this is training, evaluation or prediction. See\n      `ModeKeys`.\n    head: A `Head` instance.\n    feature_columns: An iterable containing all the feature columns used by the\n      model.\n    optimizer: string, `Optimizer` object, or callable that defines the\n      optimizer to use for training. If `None`, will use a FTRL optimizer.\n    partitioner: Partitioner for variables.\n    config: `RunConfig` object to configure the runtime settings.\n    sparse_combiner: A string specifying how to reduce if a categorical column\n      is multivalent.  One of \"mean\", \"sqrtn\", and \"sum\".\n\n  Returns:\n    An `EstimatorSpec` instance.\n\n  Raises:\n    ValueError: mode or params are invalid, or features has the wrong type.\n  \"\"\"\n  if not isinstance(features, dict):\n    raise ValueError('features should be a dictionary of `Tensor`s. 
'\n                     'Given type: {}'.format(type(features)))\n\n  num_ps_replicas = config.num_ps_replicas if config else 0\n\n  partitioner = partitioner or (tf.compat.v1.min_max_variable_partitioner(\n      max_partitions=num_ps_replicas, min_slice_size=64 << 20))\n\n  with tf.compat.v1.variable_scope(\n      'linear', values=tuple(six.itervalues(features)),\n      partitioner=partitioner):\n\n    if isinstance(optimizer, LinearSDCA):\n      assert sparse_combiner == 'sum'\n      return _sdca_model_fn(features, labels, mode, head, feature_columns,\n                            optimizer)\n    else:\n      logit_fn = linear_logit_fn_builder(\n          units=head.logits_dimension,\n          feature_columns=feature_columns,\n          sparse_combiner=sparse_combiner,\n      )\n      logits = logit_fn(features=features)\n\n      optimizer = optimizers.get_optimizer_instance(\n          optimizer or _get_default_optimizer(feature_columns),\n          learning_rate=_LEARNING_RATE)\n\n      return head.create_estimator_spec(\n          features=features,\n          mode=mode,\n          labels=labels,\n          optimizer=optimizer,\n          logits=logits)\n\n\ndef _validate_linear_sdca_optimizer_for_linear_classifier(\n    feature_columns, n_classes, optimizer, sparse_combiner):\n  \"\"\"Helper function for the initialization of LinearClassifier.\"\"\"\n  if isinstance(optimizer, LinearSDCA):\n    if sparse_combiner != 'sum':\n      raise ValueError('sparse_combiner must be \"sum\" when optimizer '\n                       'is a LinearSDCA object.')\n    if not feature_column_lib.is_feature_column_v2(feature_columns):\n      raise ValueError('V2 feature columns required when optimizer '\n                       'is a LinearSDCA object.')\n    if n_classes > 2:\n      raise ValueError('LinearSDCA cannot be used in a multi-class setting.')\n\n\n@estimator_export('estimator.LinearClassifier', v1=[])\nclass LinearClassifierV2(estimator.EstimatorV2):\n  \"\"\"Linear 
classifier model.\n\n  Train a linear model to classify instances into one of multiple possible\n  classes. When number of possible classes is 2, this is binary classification.\n\n  Example:\n\n  ```python\n  categorical_column_a = categorical_column_with_hash_bucket(...)\n  categorical_column_b = categorical_column_with_hash_bucket(...)\n\n  categorical_feature_a_x_categorical_feature_b = crossed_column(...)\n\n  # Estimator using the default optimizer.\n  estimator = tf.estimator.LinearClassifier(\n      feature_columns=[categorical_column_a,\n                       categorical_feature_a_x_categorical_feature_b])\n\n  # Or estimator using the FTRL optimizer with regularization.\n  estimator = tf.estimator.LinearClassifier(\n      feature_columns=[categorical_column_a,\n                       categorical_feature_a_x_categorical_feature_b],\n      optimizer=tf_keras.optimizers.Ftrl(\n        learning_rate=0.1,\n        l1_regularization_strength=0.001\n      ))\n\n  # Or estimator using an optimizer with a learning rate decay.\n  estimator = tf.estimator.LinearClassifier(\n      feature_columns=[categorical_column_a,\n                       categorical_feature_a_x_categorical_feature_b],\n      optimizer=lambda: tf_keras.optimizers.Ftrl(\n          learning_rate=tf.exponential_decay(\n              learning_rate=0.1,\n              global_step=tf.get_global_step(),\n              decay_steps=10000,\n              decay_rate=0.96))\n\n  # Or estimator with warm-starting from a previous checkpoint.\n  estimator = tf.estimator.LinearClassifier(\n      feature_columns=[categorical_column_a,\n                       categorical_feature_a_x_categorical_feature_b],\n      warm_start_from=\"/path/to/checkpoint/dir\")\n\n\n  # Input builders\n  def input_fn_train:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n  def input_fn_eval:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # 
index.\n    pass\n  def input_fn_predict:\n    # Returns tf.data.Dataset of (x, None) tuple.\n    pass\n  estimator.train(input_fn=input_fn_train)\n  metrics = estimator.evaluate(input_fn=input_fn_eval)\n  predictions = estimator.predict(input_fn=input_fn_predict)\n  ```\n\n  Input of `train` and `evaluate` should have following features,\n    otherwise there will be a `KeyError`:\n\n  * if `weight_column` is not `None`, a feature with `key=weight_column` whose\n    value is a `Tensor`.\n  * for each `column` in `feature_columns`:\n    - if `column` is a `SparseColumn`, a feature with `key=column.name`\n      whose `value` is a `SparseTensor`.\n    - if `column` is a `WeightedSparseColumn`, two features: the first with\n      `key` the id column name, the second with `key` the weight column name.\n      Both features' `value` must be a `SparseTensor`.\n    - if `column` is a `RealValuedColumn`, a feature with `key=column.name`\n      whose `value` is a `Tensor`.\n\n  Loss is calculated by using softmax cross entropy.\n\n  @compatibility(eager)\n  Estimators can be used while eager execution is enabled. Note that `input_fn`\n  and all hooks are executed inside a graph context, so they have to be written\n  to be compatible with graph mode. Note that `input_fn` code using `tf.data`\n  generally works in both graph and eager modes.\n  @end_compatibility\n  \"\"\"\n\n  def __init__(self,\n               feature_columns,\n               model_dir=None,\n               n_classes=2,\n               weight_column=None,\n               label_vocabulary=None,\n               optimizer='Ftrl',\n               config=None,\n               warm_start_from=None,\n               loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,\n               sparse_combiner='sum'):\n    \"\"\"Construct a `LinearClassifier` estimator object.\n\n    Args:\n      feature_columns: An iterable containing all the feature columns used by\n        the model. 
All items in the set should be instances of classes derived\n        from `FeatureColumn`.\n      model_dir: Directory to save model parameters, graph and etc. This can\n        also be used to load checkpoints from the directory into a estimator to\n        continue training a previously saved model.\n      n_classes: number of label classes. Default is binary classification. Note\n        that class labels are integers representing the class index (i.e. values\n        from 0 to n_classes-1). For arbitrary label values (e.g. string labels),\n        convert to class indices first.\n      weight_column: A string or a `_NumericColumn` created by\n        `tf.feature_column.numeric_column` defining feature column representing\n        weights. It is used to down weight or boost examples during training. It\n        will be multiplied by the loss of the example. If it is a string, it is\n        used as a key to fetch weight tensor from the `features`. If it is a\n        `_NumericColumn`, raw tensor is fetched by key `weight_column.key`, then\n        weight_column.normalizer_fn is applied on it to get weight tensor.\n      label_vocabulary: A list of strings represents possible label values. If\n        given, labels must be string type and have any value in\n        `label_vocabulary`. If it is not given, that means labels are already\n        encoded as integer or float within [0, 1] for `n_classes=2` and encoded\n        as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also\n        there will be errors if vocabulary is not provided and labels are\n        string.\n      optimizer: An instance of `tf_keras.optimizers.*` or\n        `tf.estimator.experimental.LinearSDCA` used to train the model. Can also\n        be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or\n        callable. 
Defaults to FTRL optimizer.\n      config: `RunConfig` object to configure the runtime settings.\n      warm_start_from: A string filepath to a checkpoint to warm-start from, or\n        a `WarmStartSettings` object to fully configure warm-starting.  If the\n        string filepath is provided instead of a `WarmStartSettings`, then all\n        weights and biases are warm-started, and it is assumed that vocabularies\n        and Tensor names are unchanged.\n      loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how\n        to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.\n      sparse_combiner: A string specifying how to reduce if a categorical column\n        is multivalent.  One of \"mean\", \"sqrtn\", and \"sum\" -- these are\n        effectively different ways to do example-level normalization, which can\n        be useful for bag-of-words features. for more details, see\n        `tf.feature_column.linear_model`.\n\n    Returns:\n      A `LinearClassifier` estimator.\n\n    Raises:\n      ValueError: if n_classes < 2.\n    \"\"\"\n    _validate_linear_sdca_optimizer_for_linear_classifier(\n        feature_columns=feature_columns,\n        n_classes=n_classes,\n        optimizer=optimizer,\n        sparse_combiner=sparse_combiner)\n    estimator._canned_estimator_api_gauge.get_cell('Classifier').set('Linear')  # pylint: disable=protected-access\n\n    head = head_utils.binary_or_multi_class_head(\n        n_classes,\n        weight_column=weight_column,\n        label_vocabulary=label_vocabulary,\n        loss_reduction=loss_reduction)\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"Call the defined shared _linear_model_fn.\"\"\"\n      return _linear_model_fn_v2(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          feature_columns=tuple(feature_columns or []),\n          optimizer=optimizer,\n          config=config,\n          
sparse_combiner=sparse_combiner)\n\n    super(LinearClassifierV2, self).__init__(\n        model_fn=_model_fn,\n        model_dir=model_dir,\n        config=config,\n        warm_start_from=warm_start_from)\n\n\n@estimator_export(v1=['estimator.LinearClassifier'])  # pylint: disable=missing-docstring\nclass LinearClassifier(estimator.Estimator):\n  __doc__ = LinearClassifierV2.__doc__.replace('SUM_OVER_BATCH_SIZE', 'SUM')\n\n  def __init__(self,\n               feature_columns,\n               model_dir=None,\n               n_classes=2,\n               weight_column=None,\n               label_vocabulary=None,\n               optimizer='Ftrl',\n               config=None,\n               partitioner=None,\n               warm_start_from=None,\n               loss_reduction=tf.compat.v1.losses.Reduction.SUM,\n               sparse_combiner='sum'):\n    _validate_linear_sdca_optimizer_for_linear_classifier(\n        feature_columns=feature_columns,\n        n_classes=n_classes,\n        optimizer=optimizer,\n        sparse_combiner=sparse_combiner)\n    estimator._canned_estimator_api_gauge.get_cell('Classifier').set('Linear')  # pylint: disable=protected-access\n\n    head = head_lib._binary_logistic_or_multi_class_head(  # pylint: disable=protected-access\n        n_classes, weight_column, label_vocabulary, loss_reduction)\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"Call the defined shared _linear_model_fn.\"\"\"\n      return _linear_model_fn(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          feature_columns=tuple(feature_columns or []),\n          optimizer=optimizer,\n          partitioner=partitioner,\n          config=config,\n          sparse_combiner=sparse_combiner)\n\n    super(LinearClassifier, self).__init__(\n        model_fn=_model_fn,\n        model_dir=model_dir,\n        config=config,\n        
warm_start_from=warm_start_from)\n\n\n@estimator_export('estimator.LinearEstimator', v1=[])\nclass LinearEstimatorV2(estimator.EstimatorV2):\n  \"\"\"An estimator for TensorFlow linear models with user-specified head.\n\n  Example:\n\n  ```python\n  categorical_column_a = categorical_column_with_hash_bucket(...)\n  categorical_column_b = categorical_column_with_hash_bucket(...)\n\n  categorical_feature_a_x_categorical_feature_b = crossed_column(...)\n\n  # Estimator using the default optimizer.\n  estimator = tf.estimator.LinearEstimator(\n      head=tf.estimator.MultiLabelHead(n_classes=3),\n      feature_columns=[categorical_column_a,\n                       categorical_feature_a_x_categorical_feature_b])\n\n  # Or estimator using an optimizer with a learning rate decay.\n  estimator = tf.estimator.LinearEstimator(\n      head=tf.estimator.MultiLabelHead(n_classes=3),\n      feature_columns=[categorical_column_a,\n                       categorical_feature_a_x_categorical_feature_b],\n      optimizer=lambda: tf_keras.optimizers.Ftrl(\n          learning_rate=tf.compat.v1.train.exponential_decay(\n              learning_rate=0.1,\n              global_step=tf.compat.v1.train.get_global_step(),\n              decay_steps=10000,\n              decay_rate=0.96))\n\n  # Or estimator using the FTRL optimizer with regularization.\n  estimator = tf.estimator.LinearEstimator(\n      head=tf.estimator.MultiLabelHead(n_classes=3),\n      feature_columns=[categorical_column_a,\n                       categorical_feature_a_x_categorical_feature_b])\n      optimizer=tf_keras.optimizers.Ftrl(\n          learning_rate=0.1,\n          l1_regularization_strength=0.001\n      ))\n\n  def input_fn_train:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n  def input_fn_eval:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n  def input_fn_predict:\n    # Returns 
tf.data.Dataset of (x, None) tuple.\n    pass\n  estimator.train(input_fn=input_fn_train, steps=100)\n  metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)\n  predictions = estimator.predict(input_fn=input_fn_predict)\n  ```\n\n  Input of `train` and `evaluate` should have following features,\n  otherwise there will be a `KeyError`:\n\n  * if `weight_column` is not `None`, a feature with `key=weight_column` whose\n    value is a `Tensor`.\n  * for each `column` in `feature_columns`:\n    - if `column` is a `CategoricalColumn`, a feature with `key=column.name`\n      whose `value` is a `SparseTensor`.\n    - if `column` is a `WeightedCategoricalColumn`, two features: the first\n      with `key` the id column name, the second with `key` the weight column\n      name. Both features' `value` must be a `SparseTensor`.\n    - if `column` is a `DenseColumn`, a feature with `key=column.name`\n      whose `value` is a `Tensor`.\n\n  Loss and predicted output are determined by the specified head.\n\n  @compatibility(eager)\n  Estimators can be used while eager execution is enabled. Note that `input_fn`\n  and all hooks are executed inside a graph context, so they have to be written\n  to be compatible with graph mode. Note that `input_fn` code using `tf.data`\n  generally works in both graph and eager modes.\n  @end_compatibility\n  \"\"\"\n\n  def __init__(self,\n               head,\n               feature_columns,\n               model_dir=None,\n               optimizer='Ftrl',\n               config=None,\n               sparse_combiner='sum',\n               warm_start_from=None):\n    \"\"\"Initializes a `LinearEstimator` instance.\n\n    Args:\n      head: A `Head` instance constructed with a method such as\n        `tf.estimator.MultiLabelHead`.\n      feature_columns: An iterable containing all the feature columns used by\n        the model. 
All items in the set should be instances of classes derived\n        from `FeatureColumn`.\n      model_dir: Directory to save model parameters, graph and etc. This can\n        also be used to load checkpoints from the directory into a estimator to\n        continue training a previously saved model.\n      optimizer: An instance of `tf_keras.optimizers.*` used to train the model.\n        Can also be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp',\n        'SGD'), or callable. Defaults to FTRL optimizer.\n      config: `RunConfig` object to configure the runtime settings.\n      sparse_combiner: A string specifying how to reduce if a categorical column\n        is multivalent.  One of \"mean\", \"sqrtn\", and \"sum\" -- these are\n        effectively different ways to do example-level normalization, which can\n        be useful for bag-of-words features. for more details, see\n        `tf.feature_column.linear_model`.\n      warm_start_from: A string filepath to a checkpoint to warm-start from, or\n        a `WarmStartSettings` object to fully configure warm-starting.  
If the\n        string filepath is provided instead of a `WarmStartSettings`, then all\n        weights and biases are warm-started, and it is assumed that vocabularies\n        and Tensor names are unchanged.\n    \"\"\"\n\n    def _model_fn(features, labels, mode, config):\n      return _linear_model_fn_v2(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          feature_columns=tuple(feature_columns or []),\n          optimizer=optimizer,\n          config=config,\n          sparse_combiner=sparse_combiner)\n\n    estimator._canned_estimator_api_gauge.get_cell('Estimator').set('Linear')  # pylint: disable=protected-access\n    super(LinearEstimatorV2, self).__init__(\n        model_fn=_model_fn, model_dir=model_dir, config=config,\n        warm_start_from=warm_start_from)\n\n\n@estimator_export(v1=['estimator.LinearEstimator'])  # pylint: disable=missing-docstring\nclass LinearEstimator(estimator.Estimator):\n  __doc__ = LinearEstimatorV2.__doc__\n\n  def __init__(self,\n               head,\n               feature_columns,\n               model_dir=None,\n               optimizer='Ftrl',\n               config=None,\n               partitioner=None,\n               sparse_combiner='sum',\n               warm_start_from=None):\n    \"\"\"Initializes a `LinearEstimator` instance.\n\n    Args:\n      head: A `_Head` instance constructed with a method such as\n        `tf.contrib.estimator.multi_label_head`.\n      feature_columns: An iterable containing all the feature columns used by\n        the model. All items in the set should be instances of classes derived\n        from `FeatureColumn`.\n      model_dir: Directory to save model parameters, graph and etc. This can\n        also be used to load checkpoints from the directory into a estimator to\n        continue training a previously saved model.\n      optimizer: An instance of `tf.Optimizer` used to train the model. 
Can also\n        be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or\n        callable. Defaults to FTRL optimizer.\n      config: `RunConfig` object to configure the runtime settings.\n      partitioner: Optional. Partitioner for input layer.\n      sparse_combiner: A string specifying how to reduce if a categorical column\n        is multivalent.  One of \"mean\", \"sqrtn\", and \"sum\" -- these are\n        effectively different ways to do example-level normalization, which can\n        be useful for bag-of-words features. for more details, see\n        `tf.feature_column.linear_model`.\n      warm_start_from: A string filepath to a checkpoint to warm-start from, or\n        a `WarmStartSettings` object to fully configure warm-starting.  If the\n        string filepath is provided instead of a `WarmStartSettings`, then all\n        weights and biases are warm-started, and it is assumed that vocabularies\n        and Tensor names are unchanged.\n    \"\"\"\n\n    def _model_fn(features, labels, mode, config):\n      return _linear_model_fn(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          feature_columns=tuple(feature_columns or []),\n          optimizer=optimizer,\n          partitioner=partitioner,\n          config=config,\n          sparse_combiner=sparse_combiner)\n\n    estimator._canned_estimator_api_gauge.get_cell('Estimator').set('Linear')  # pylint: disable=protected-access\n    super(LinearEstimator, self).__init__(\n        model_fn=_model_fn, model_dir=model_dir, config=config,\n        warm_start_from=warm_start_from)\n\n\ndef _validate_linear_sdca_optimizer_for_linear_regressor(\n    feature_columns, label_dimension, optimizer, sparse_combiner):\n  \"\"\"Helper function for the initialization of LinearRegressor.\"\"\"\n  if isinstance(optimizer, LinearSDCA):\n    if sparse_combiner != 'sum':\n      raise ValueError('sparse_combiner must be \"sum\" when optimizer '\n     
                  'is a LinearSDCA object.')\n    if not feature_column_lib.is_feature_column_v2(feature_columns):\n      raise ValueError('V2 feature columns required when optimizer '\n                       'is a LinearSDCA object.')\n    if label_dimension > 1:\n      raise ValueError('LinearSDCA can only be used with one-dimensional '\n                       'label.')\n\n\n@estimator_export('estimator.LinearRegressor', v1=[])\nclass LinearRegressorV2(estimator.EstimatorV2):\n  \"\"\"An estimator for TensorFlow Linear regression problems.\n\n  Train a linear regression model to predict label value given observation of\n  feature values.\n\n  Example:\n\n  ```python\n  categorical_column_a = categorical_column_with_hash_bucket(...)\n  categorical_column_b = categorical_column_with_hash_bucket(...)\n\n  categorical_feature_a_x_categorical_feature_b = crossed_column(...)\n\n  # Estimator using the default optimizer.\n  estimator = tf.estimator.LinearRegressor(\n      feature_columns=[categorical_column_a,\n                       categorical_feature_a_x_categorical_feature_b])\n\n  # Or estimator using the FTRL optimizer with regularization.\n  estimator = tf.estimator.LinearRegressor(\n      feature_columns=[categorical_column_a,\n                       categorical_feature_a_x_categorical_feature_b],\n      optimizer=tf_keras.optimizers.Ftrl(\n        learning_rate=0.1,\n        l1_regularization_strength=0.001\n      ))\n\n  # Or estimator using an optimizer with a learning rate decay.\n  estimator = tf.estimator.LinearRegressor(\n      feature_columns=[categorical_column_a,\n                       categorical_feature_a_x_categorical_feature_b],\n      optimizer=lambda: tf_keras.optimizers.Ftrl(\n          learning_rate=tf.compat.v1.train.exponential_decay(\n              learning_rate=0.1,\n              global_step=tf.compat.v1.train.get_global_step(),\n              decay_steps=10000,\n              decay_rate=0.96))\n\n  # Or estimator with warm-starting from 
a previous checkpoint.\n  estimator = tf.estimator.LinearRegressor(\n      feature_columns=[categorical_column_a,\n                       categorical_feature_a_x_categorical_feature_b],\n      warm_start_from=\"/path/to/checkpoint/dir\")\n\n\n  # Input builders\n  def input_fn_train:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n  def input_fn_eval:\n    # Returns tf.data.Dataset of (x, y) tuple where y represents label's class\n    # index.\n    pass\n  def input_fn_predict:\n    # Returns tf.data.Dataset of (x, None) tuple.\n    pass\n  estimator.train(input_fn=input_fn_train)\n  metrics = estimator.evaluate(input_fn=input_fn_eval)\n  predictions = estimator.predict(input_fn=input_fn_predict)\n  ```\n\n  Input of `train` and `evaluate` should have following features,\n    otherwise there will be a KeyError:\n\n  * if `weight_column` is not `None`, a feature with `key=weight_column` whose\n    value is a `Tensor`.\n  * for each `column` in `feature_columns`:\n    - if `column` is a `SparseColumn`, a feature with `key=column.name`\n      whose `value` is a `SparseTensor`.\n    - if `column` is a `WeightedSparseColumn`, two features: the first with\n      `key` the id column name, the second with `key` the weight column name.\n      Both features' `value` must be a `SparseTensor`.\n    - if `column` is a `RealValuedColumn`, a feature with `key=column.name`\n      whose `value` is a `Tensor`.\n\n  Loss is calculated by using mean squared error.\n\n  @compatibility(eager)\n  Estimators can be used while eager execution is enabled. Note that `input_fn`\n  and all hooks are executed inside a graph context, so they have to be written\n  to be compatible with graph mode. 
Note that `input_fn` code using `tf.data`\n  generally works in both graph and eager modes.\n  @end_compatibility\n  \"\"\"\n\n  def __init__(self,\n               feature_columns,\n               model_dir=None,\n               label_dimension=1,\n               weight_column=None,\n               optimizer='Ftrl',\n               config=None,\n               warm_start_from=None,\n               loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,\n               sparse_combiner='sum'):\n    \"\"\"Initializes a `LinearRegressor` instance.\n\n    Args:\n      feature_columns: An iterable containing all the feature columns used by\n        the model. All items in the set should be instances of classes derived\n        from `FeatureColumn`.\n      model_dir: Directory to save model parameters, graph and etc. This can\n        also be used to load checkpoints from the directory into a estimator to\n        continue training a previously saved model.\n      label_dimension: Number of regression targets per example. This is the\n        size of the last dimension of the labels and logits `Tensor` objects\n        (typically, these have shape `[batch_size, label_dimension]`).\n      weight_column: A string or a `NumericColumn` created by\n        `tf.feature_column.numeric_column` defining feature column representing\n        weights. It is used to down weight or boost examples during training. It\n        will be multiplied by the loss of the example. If it is a string, it is\n        used as a key to fetch weight tensor from the `features`. If it is a\n        `NumericColumn`, raw tensor is fetched by key `weight_column.key`, then\n        weight_column.normalizer_fn is applied on it to get weight tensor.\n      optimizer: An instance of `tf_keras.optimizers.*` or\n        `tf.estimator.experimental.LinearSDCA` used to train the model. Can also\n        be a string (one of 'Adagrad', 'Adam', 'Ftrl', 'RMSProp', 'SGD'), or\n        callable. 
Defaults to FTRL optimizer.\n      config: `RunConfig` object to configure the runtime settings.\n      warm_start_from: A string filepath to a checkpoint to warm-start from, or\n        a `WarmStartSettings` object to fully configure warm-starting.  If the\n        string filepath is provided instead of a `WarmStartSettings`, then all\n        weights and biases are warm-started, and it is assumed that vocabularies\n        and Tensor names are unchanged.\n      loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how\n        to reduce training loss over batch. Defaults to `SUM`.\n      sparse_combiner: A string specifying how to reduce if a categorical column\n        is multivalent.  One of \"mean\", \"sqrtn\", and \"sum\" -- these are\n        effectively different ways to do example-level normalization, which can\n        be useful for bag-of-words features. for more details, see\n        `tf.feature_column.linear_model`.\n    \"\"\"\n    _validate_linear_sdca_optimizer_for_linear_regressor(\n        feature_columns=feature_columns,\n        label_dimension=label_dimension,\n        optimizer=optimizer,\n        sparse_combiner=sparse_combiner)\n\n    head = regression_head.RegressionHead(\n        label_dimension=label_dimension,\n        weight_column=weight_column,\n        loss_reduction=loss_reduction)\n    estimator._canned_estimator_api_gauge.get_cell('Regressor').set('Linear')  # pylint: disable=protected-access\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"Call the defined shared _linear_model_fn.\"\"\"\n      return _linear_model_fn_v2(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          feature_columns=tuple(feature_columns or []),\n          optimizer=optimizer,\n          config=config,\n          sparse_combiner=sparse_combiner)\n\n    super(LinearRegressorV2, self).__init__(\n        model_fn=_model_fn,\n        model_dir=model_dir,\n        
config=config,\n        warm_start_from=warm_start_from)\n\n\n@estimator_export(v1=['estimator.LinearRegressor'])  # pylint: disable=missing-docstring\nclass LinearRegressor(estimator.Estimator):\n  __doc__ = LinearRegressorV2.__doc__.replace('SUM_OVER_BATCH_SIZE', 'SUM')\n\n  def __init__(self,\n               feature_columns,\n               model_dir=None,\n               label_dimension=1,\n               weight_column=None,\n               optimizer='Ftrl',\n               config=None,\n               partitioner=None,\n               warm_start_from=None,\n               loss_reduction=tf.compat.v1.losses.Reduction.SUM,\n               sparse_combiner='sum'):\n    _validate_linear_sdca_optimizer_for_linear_regressor(\n        feature_columns=feature_columns,\n        label_dimension=label_dimension,\n        optimizer=optimizer,\n        sparse_combiner=sparse_combiner)\n\n    head = head_lib._regression_head(  # pylint: disable=protected-access\n        label_dimension=label_dimension,\n        weight_column=weight_column,\n        loss_reduction=loss_reduction)\n    estimator._canned_estimator_api_gauge.get_cell('Regressor').set('Linear')  # pylint: disable=protected-access\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"Call the defined shared _linear_model_fn.\"\"\"\n      return _linear_model_fn(\n          features=features,\n          labels=labels,\n          mode=mode,\n          head=head,\n          feature_columns=tuple(feature_columns or []),\n          optimizer=optimizer,\n          partitioner=partitioner,\n          config=config,\n          sparse_combiner=sparse_combiner)\n\n    super(LinearRegressor, self).__init__(\n        model_fn=_model_fn,\n        model_dir=model_dir,\n        config=config,\n        warm_start_from=warm_start_from)\n\n\nclass _LinearModelLayer(tf_keras.layers.Layer):\n  \"\"\"Layer that contains logic for `LinearModel`.\"\"\"\n\n  def __init__(self,\n               feature_columns,\n               
units=1,\n               sparse_combiner='sum',\n               trainable=True,\n               name=None,\n               **kwargs):\n    super(_LinearModelLayer, self).__init__(\n        name=name, trainable=trainable, **kwargs)\n\n    self._feature_columns = fc_v2._normalize_feature_columns(feature_columns)  # pylint: disable=protected-access\n    for column in self._feature_columns:\n      if not isinstance(column, (tf.compat.v2.__internal__.feature_column.DenseColumn, fc_v2.CategoricalColumn)):\n        raise ValueError(\n            'Items of feature_columns must be either a '\n            'DenseColumn or CategoricalColumn. Given: {}'.format(column))\n\n    self._units = units\n    self._sparse_combiner = sparse_combiner\n\n    self._state_manager = tf.compat.v2.__internal__.feature_column.StateManager(self, self.trainable)  # pylint: disable=protected-access\n    self.bias = None\n\n  def build(self, _):\n    # We need variable scopes for now because we want the variable partitioning\n    # information to percolate down. 
We also use _pure_variable_scope's here\n    # since we want to open up a name_scope in the `call` method while creating\n    # the ops.\n    with variable_scope._pure_variable_scope(self.name):  # pylint: disable=protected-access\n      for column in self._feature_columns:\n        with variable_scope._pure_variable_scope(  # pylint: disable=protected-access\n            fc_v2._sanitize_column_name_for_variable_scope(column.name)):  # pylint: disable=protected-access\n          # Create the state for each feature column\n          column.create_state(self._state_manager)\n\n          # Create a weight variable for each column.\n          if isinstance(column, fc_v2.CategoricalColumn):\n            first_dim = column.num_buckets\n          else:\n            first_dim = column.variable_shape.num_elements()\n          self._state_manager.create_variable(\n              column,\n              name='weights',\n              dtype=tf.float32,\n              shape=(first_dim, self._units),\n              initializer=tf_keras.initializers.zeros(),\n              trainable=self.trainable)\n\n      # Create a bias variable.\n      self.bias = self.add_weight(\n          name='bias_weights',\n          dtype=tf.float32,\n          shape=[self._units],\n          initializer=tf_keras.initializers.zeros(),\n          trainable=self.trainable,\n          use_resource=True,\n          # TODO(rohanj): Get rid of this hack once we have a mechanism for\n          # specifying a default partitioner for an entire layer. In that case,\n          # the default getter for Layers should work.\n          getter=tf.compat.v1.get_variable)\n\n    super(_LinearModelLayer, self).build(None)\n\n  def call(self, features):\n    if not isinstance(features, dict):\n      raise ValueError('We expected a dictionary here. 
Instead we got: {}'\n                       .format(features))\n    with ops.name_scope(self.name):\n      transformation_cache = tf.compat.v2.__internal__.feature_column.FeatureTransformationCache(features)\n      weighted_sums = []\n      for column in self._feature_columns:\n        with ops.name_scope(\n            fc_v2._sanitize_column_name_for_variable_scope(column.name)):  # pylint: disable=protected-access\n          # All the weights used in the linear model are owned by the state\n          # manager associated with this Linear Model.\n          weight_var = self._state_manager.get_variable(column, 'weights')\n\n          weighted_sum = fc_v2._create_weighted_sum(  # pylint: disable=protected-access\n              column=column,\n              transformation_cache=transformation_cache,\n              state_manager=self._state_manager,\n              sparse_combiner=self._sparse_combiner,\n              weight_var=weight_var)\n          weighted_sums.append(weighted_sum)\n\n      fc_v2._verify_static_batch_size_equality(  # pylint: disable=protected-access\n          weighted_sums, self._feature_columns)\n      predictions_no_bias = tf.math.add_n(\n          weighted_sums, name='weighted_sum_no_bias')\n      predictions = tf.nn.bias_add(\n          predictions_no_bias, self.bias, name='weighted_sum')\n      return predictions\n\n  def get_config(self):\n    # Import here to avoid circular imports.\n    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top\n    column_configs = serialization.serialize_feature_columns(\n        self._feature_columns)\n    config = {\n        'feature_columns': column_configs,\n        'units': self._units,\n        'sparse_combiner': self._sparse_combiner\n    }\n\n    base_config = super(  # pylint: disable=bad-super-call\n        _LinearModelLayer, self).get_config()\n    return dict(list(base_config.items()) + list(config.items()))\n\n  @classmethod\n  def from_config(cls, 
config, custom_objects=None):\n    # Import here to avoid circular imports.\n    from tensorflow.python.feature_column import serialization  # pylint: disable=g-import-not-at-top\n    config_cp = config.copy()\n    columns = serialization.deserialize_feature_columns(\n        config_cp['feature_columns'], custom_objects=custom_objects)\n\n    del config_cp['feature_columns']\n    return cls(feature_columns=columns, **config_cp)\n\n\nclass LinearModel(tf_keras.Model):\n  \"\"\"Produces a linear prediction `Tensor` based on given `feature_columns`.\n\n  This layer generates a weighted sum based on output dimension `units`.\n  Weighted sum refers to logits in classification problems. It refers to the\n  prediction itself for linear regression problems.\n\n  Note on supported columns: `LinearLayer` treats categorical columns as\n  `indicator_column`s. To be specific, assume the input as `SparseTensor` looks\n  like:\n\n  ```python\n    shape = [2, 2]\n    {\n        [0, 0]: \"a\"\n        [1, 0]: \"b\"\n        [1, 1]: \"c\"\n    }\n  ```\n  `linear_model` assigns weights for the presence of \"a\", \"b\", \"c' implicitly,\n  just like `indicator_column`, while `input_layer` explicitly requires wrapping\n  each of categorical columns with an `embedding_column` or an\n  `indicator_column`.\n\n  Example of usage:\n\n  ```python\n  price = numeric_column('price')\n  price_buckets = bucketized_column(price, boundaries=[0., 10., 100., 1000.])\n  keywords = categorical_column_with_hash_bucket(\"keywords\", 10K)\n  keywords_price = crossed_column('keywords', price_buckets, ...)\n  columns = [price_buckets, keywords, keywords_price ...]\n  linear_model = LinearLayer(columns)\n\n  features = tf.io.parse_example(..., features=make_parse_example_spec(columns))\n  prediction = linear_model(features)\n  ```\n  \"\"\"\n\n  def __init__(self,\n               feature_columns,\n               units=1,\n               sparse_combiner='sum',\n               trainable=True,\n               
name=None,\n               **kwargs):\n    \"\"\"Constructs a LinearLayer.\n\n    Args:\n      feature_columns: An iterable containing the FeatureColumns to use as\n        inputs to your model. All items should be instances of classes derived\n        from `_FeatureColumn`s.\n      units: An integer, dimensionality of the output space. Default value is 1.\n      sparse_combiner: A string specifying how to reduce if a categorical column\n        is multivalent. Except `numeric_column`, almost all columns passed to\n        `linear_model` are considered as categorical columns.  It combines each\n        categorical column independently. Currently \"mean\", \"sqrtn\" and \"sum\"\n        are supported, with \"sum\" the default for linear model. \"sqrtn\" often\n        achieves good accuracy, in particular with bag-of-words columns.\n          * \"sum\": do not normalize features in the column\n          * \"mean\": do l1 normalization on features in the column\n          * \"sqrtn\": do l2 normalization on features in the column\n        For example, for two features represented as the categorical columns:\n\n          ```python\n          # Feature 1\n\n          shape = [2, 2]\n          {\n              [0, 0]: \"a\"\n              [0, 1]: \"b\"\n              [1, 0]: \"c\"\n          }\n\n          # Feature 2\n\n          shape = [2, 3]\n          {\n              [0, 0]: \"d\"\n              [1, 0]: \"e\"\n              [1, 1]: \"f\"\n              [1, 2]: \"g\"\n          }\n          ```\n\n        with `sparse_combiner` as \"mean\", the linear model outputs conceptually\n        are\n        ```\n        y_0 = 1.0 / 2.0 * ( w_a + w_ b) + w_c + b_0\n        y_1 = w_d + 1.0 / 3.0 * ( w_e + w_ f + w_g) + b_1\n        ```\n        where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight\n        assigned to the presence of `x` in the input features.\n      trainable: If `True` also add the variable to the graph collection\n        
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).\n      name: Name to give to the Linear Model. All variables and ops created will\n        be scoped by this name.\n      **kwargs: Keyword arguments to construct a layer.\n\n    Raises:\n      ValueError: if an item in `feature_columns` is neither a `DenseColumn`\n        nor `CategoricalColumn`.\n    \"\"\"\n\n    super(LinearModel, self).__init__(name=name, **kwargs)\n    self.layer = _LinearModelLayer(\n        feature_columns,\n        units,\n        sparse_combiner,\n        trainable,\n        name=self.name,\n        **kwargs)\n\n  def call(self, features):\n    \"\"\"Returns a `Tensor` the represents the predictions of a linear model.\n\n    Args:\n      features: A mapping from key to tensors. `_FeatureColumn`s look up via\n        these keys. For example `numeric_column('price')` will look at 'price'\n        key in this dict. Values are `Tensor` or `SparseTensor` depending on\n        corresponding `_FeatureColumn`.\n\n    Returns:\n      A `Tensor` which represents predictions/logits of a linear model. Its\n      shape is (batch_size, units) and its dtype is `float32`.\n\n    Raises:\n      ValueError: If features are not a dictionary.\n    \"\"\"\n    return self.layer(features)\n\n  @property\n  def bias(self):\n    return self.layer.bias\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/linear_estimator_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for LinearEstimator.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport shutil\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import linear\nfrom tensorflow_estimator.python.estimator.canned import linear_testing_utils\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.head import multi_class_head\nfrom tensorflow_estimator.python.estimator.head import regression_head\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\n\n\ndef _linear_estimator_fn(weight_column=None, label_dimension=1, **kwargs):\n  \"\"\"Returns a LinearEstimator that uses regression_head.\"\"\"\n  return linear.LinearEstimatorV2(\n      head=regression_head.RegressionHead(\n          weight_column=weight_column,\n          label_dimension=label_dimension,\n          # Tests in core (from which this test inherits) test the sum loss.\n          loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE),\n      
**kwargs)\n\n\ndef _linear_estimator_classifier_fn(n_classes=3, **kwargs):\n  return linear.LinearEstimatorV2(\n      head=multi_class_head.MultiClassHead(n_classes=n_classes), **kwargs)\n\n\nclass LinearEstimatorEvaluateTest(\n    linear_testing_utils.BaseLinearRegressorEvaluationTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(\n        self, _linear_estimator_fn)\n\n\nclass LinearEstimatorPredictTest(\n    linear_testing_utils.BaseLinearRegressorPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearRegressorPredictTest.__init__(\n        self, _linear_estimator_fn)\n\n\nclass LinearEstimatorTrainTest(\n    linear_testing_utils.BaseLinearRegressorTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(\n        self, _linear_estimator_fn)\n\n\nclass LinearEstimatorWarmStartingTest(\n    linear_testing_utils.BaseLinearWarmStartingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearWarmStartingTest.__init__(\n        self,\n        _linear_estimator_classifier_fn,\n        _linear_estimator_fn)\n\n\nclass LinearEstimatorIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self,\n                          train_input_fn,\n                      
    eval_input_fn,\n                          predict_input_fn,\n                          input_dimension,\n                          label_dimension,\n                          batch_size,\n                          optimizer='Ftrl'):\n    feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    est = linear.LinearEstimatorV2(\n        head=regression_head.RegressionHead(label_dimension=label_dimension),\n        feature_columns=feature_columns,\n        model_dir=self._model_dir,\n        optimizer=optimizer)\n\n    # Train\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # Evaluate\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # Predict\n    predictions = np.array([\n        x[prediction_keys.PredictionKeys.PREDICTIONS]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, label_dimension), predictions.shape)\n\n    # Export\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def _create_input_fn(self, label_dimension, batch_size):\n    \"\"\"Creates input_fn for integration test.\"\"\"\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, 
batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, batch_size=batch_size, shuffle=False)\n\n    return train_input_fn, eval_input_fn, predict_input_fn\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    train_input_fn, eval_input_fn, predict_input_fn = self._create_input_fn(\n        label_dimension, batch_size)\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size)\n\n  def test_numpy_input_fn_with_optimizer_instance(self):\n    \"\"\"Tests complete flow with optimizer_v2 instance.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    train_input_fn, eval_input_fn, predict_input_fn = self._create_input_fn(\n        label_dimension, batch_size)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size,\n        optimizer=tf_keras.optimizers.legacy.Ftrl(0.01))  # Test with optimizer_v2 instance\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/linear_model_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for feature_column.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\n\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.core.protobuf import rewriter_config_pb2\nfrom tensorflow.python.feature_column import feature_column_v2 as fc\nfrom tensorflow.python.framework import test_util\nfrom tensorflow.python.platform import flags\nfrom tensorflow_estimator.python.estimator.canned import linear\n\n\ndef _initialized_session(config=None):\n  sess = tf.compat.v1.Session(config=config)\n  sess.run(tf.compat.v1.initializers.global_variables())\n  sess.run(tf.compat.v1.tables_initializer())\n  return sess\n\n\ndef get_linear_model_bias(name='linear_model'):\n  with tf.compat.v1.variable_scope(name, reuse=True):\n    return tf.compat.v1.get_variable('bias_weights')\n\n\ndef get_linear_model_column_var(column, name='linear_model'):\n  return tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES,\n                            name + '/' + column.name)[0]\n\n\nclass BaseFeatureColumnForTests(tf.compat.v2.__internal__.feature_column.FeatureColumn):\n  \"\"\"A base FeatureColumn useful to avoid boiler-plate in 
tests.\n\n  Provides dummy implementations for abstract methods that raise ValueError in\n  order to avoid re-defining all abstract methods for each test sub-class.\n  \"\"\"\n\n  @property\n  def parents(self):\n    raise ValueError('Should not use this method.')\n\n  @classmethod\n  def from_config(cls, config, custom_objects=None, columns_by_name=None):\n    raise ValueError('Should not use this method.')\n\n  def get_config(self):\n    raise ValueError('Should not use this method.')\n\n\nclass SortableFeatureColumnTest(tf.test.TestCase):\n\n  @test_util.run_deprecated_v1\n  def test_linear_model(self):\n    price = tf.feature_column.numeric_column('price')\n    with tf.Graph().as_default():\n      features = {'price': [[1.], [5.]]}\n      model = linear.LinearModel([price])\n      predictions = model(features)\n      price_var, bias = model.variables\n      with _initialized_session() as sess:\n        self.assertAllClose([0.], self.evaluate(bias))\n        self.assertAllClose([[0.]], self.evaluate(price_var))\n        self.assertAllClose([[0.], [0.]], self.evaluate(predictions))\n        sess.run(price_var.assign([[10.]]))\n        self.assertAllClose([[10.], [50.]], self.evaluate(predictions))\n\n  @test_util.run_deprecated_v1\n  def test_linear_model_sanitizes_scope_names(self):\n    price = tf.feature_column.numeric_column('price > 100')\n    with tf.Graph().as_default():\n      features = {'price > 100': [[1.], [5.]]}\n      model = linear.LinearModel([price])\n      predictions = model(features)\n      price_var, bias = model.variables\n      with _initialized_session() as sess:\n        self.assertAllClose([0.], self.evaluate(bias))\n        self.assertAllClose([[0.]], self.evaluate(price_var))\n        self.assertAllClose([[0.], [0.]], self.evaluate(predictions))\n        sess.run(price_var.assign([[10.]]))\n        self.assertAllClose([[10.], [50.]], self.evaluate(predictions))\n\n\nclass BucketizedColumnTest(tf.test.TestCase):\n\n  def 
test_linear_model_one_input_value(self):\n    \"\"\"Tests linear_model() for input with shape=[1].\"\"\"\n    price = tf.feature_column.numeric_column('price', shape=[1])\n    bucketized_price = tf.feature_column.bucketized_column(price, boundaries=[0, 2, 4, 6])\n    with tf.Graph().as_default():\n      features = {'price': [[-1.], [1.], [5.], [6.]]}\n      model = linear.LinearModel([bucketized_price])\n      predictions = model(features)\n      bucketized_price_var, bias = model.variables\n      with _initialized_session() as sess:\n        self.assertAllClose([0.], self.evaluate(bias))\n        # One weight variable per bucket, all initialized to zero.\n        self.assertAllClose([[0.], [0.], [0.], [0.], [0.]],\n                            self.evaluate(bucketized_price_var))\n        self.assertAllClose([[0.], [0.], [0.], [0.]],\n                            self.evaluate(predictions))\n        sess.run(\n            bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.]]))\n        # price -1. is in the 0th bucket, whose weight is 10.\n        # price 1. is in the 1st bucket, whose weight is 20.\n        # price 5. is in the 3rd bucket, whose weight is 40.\n        # price 6. 
is in the 4th bucket, whose weight is 50.\n        self.assertAllClose([[10.], [20.], [40.], [50.]],\n                            self.evaluate(predictions))\n        sess.run(bias.assign([1.]))\n        self.assertAllClose([[11.], [21.], [41.], [51.]],\n                            self.evaluate(predictions))\n\n  def test_linear_model_two_input_values(self):\n    \"\"\"Tests linear_model() for input with shape=[2].\"\"\"\n    price = tf.feature_column.numeric_column('price', shape=[2])\n    bucketized_price = tf.feature_column.bucketized_column(price, boundaries=[0, 2, 4, 6])\n    with tf.Graph().as_default():\n      features = {'price': [[-1., 1.], [5., 6.]]}\n      model = linear.LinearModel([bucketized_price])\n      predictions = model(features)\n      bucketized_price_var, bias = model.variables\n      with _initialized_session() as sess:\n        self.assertAllClose([0.], self.evaluate(bias))\n        # One weight per bucket per input column, all initialized to zero.\n        self.assertAllClose(\n            [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],\n            self.evaluate(bucketized_price_var))\n        self.assertAllClose([[0.], [0.]], self.evaluate(predictions))\n        sess.run(\n            bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.],\n                                         [60.], [70.], [80.], [90.], [100.]]))\n        # 1st example:\n        #   price -1. is in the 0th bucket, whose weight is 10.\n        #   price 1. is in the 6th bucket, whose weight is 70.\n        # 2nd example:\n        #   price 5. is in the 3rd bucket, whose weight is 40.\n        #   price 6. 
is in the 9th bucket, whose weight is 100.\n        self.assertAllClose([[80.], [140.]], self.evaluate(predictions))\n        sess.run(bias.assign([1.]))\n        self.assertAllClose([[81.], [141.]], self.evaluate(predictions))\n\n\nclass HashedCategoricalColumnTest(tf.test.TestCase):\n\n  @test_util.run_deprecated_v1\n  def test_linear_model(self):\n    wire_column = tf.feature_column.categorical_column_with_hash_bucket('wire', 4)\n    self.assertEqual(4, wire_column.num_buckets)\n    with tf.Graph().as_default():\n      model = linear.LinearModel((wire_column,))\n      predictions = model({\n          wire_column.name:\n              tf.compat.v1.SparseTensorValue(\n                  indices=((0, 0), (1, 0), (1, 1)),\n                  values=('marlo', 'skywalker', 'omar'),\n                  dense_shape=(2, 2))\n      })\n      wire_var, bias = model.variables\n\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.evaluate(tf.compat.v1.tables_initializer())\n\n      self.assertAllClose((0.,), self.evaluate(bias))\n      self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), self.evaluate(wire_var))\n      self.assertAllClose(((0.,), (0.,)), self.evaluate(predictions))\n      self.evaluate(wire_var.assign(((1.,), (2.,), (3.,), (4.,))))\n      # 'marlo' -> 3: wire_var[3] = 4\n      # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6\n      self.assertAllClose(((4.,), (6.,)), self.evaluate(predictions))\n\n\nclass CrossedColumnTest(tf.test.TestCase):\n\n  @test_util.run_deprecated_v1\n  def test_linear_model(self):\n    \"\"\"Tests linear_model.\n\n    Uses data from test_get_sparse_tensors_simple.\n    \"\"\"\n    a = tf.feature_column.numeric_column('a', dtype=tf.int32, shape=(2,))\n    b = tf.feature_column.bucketized_column(a, boundaries=(0, 1))\n    crossed = tf.feature_column.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)\n    with tf.Graph().as_default():\n      model = linear.LinearModel((crossed,))\n      
predictions = model({\n          'a':\n              tf.compat.v2.constant(((-1., .5), (.5, 1.))),\n          'c':\n              tf.sparse.SparseTensor(\n                  indices=((0, 0), (1, 0), (1, 1)),\n                  values=['cA', 'cB', 'cC'],\n                  dense_shape=(2, 2)),\n      })\n      crossed_var, bias = model.variables\n      with _initialized_session() as sess:\n        self.assertAllClose((0.,), self.evaluate(bias))\n        self.assertAllClose(((0.,), (0.,), (0.,), (0.,), (0.,)),\n                            self.evaluate(crossed_var))\n        self.assertAllClose(((0.,), (0.,)), self.evaluate(predictions))\n        sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))\n        # Expected ids after cross = (1, 0, 1, 3, 4, 2)\n        self.assertAllClose(((3.,), (14.,)), self.evaluate(predictions))\n        sess.run(bias.assign((.1,)))\n        self.assertAllClose(((3.1,), (14.1,)), self.evaluate(predictions))\n\n  def test_linear_model_with_weights(self):\n\n    class _TestColumnWithWeights(BaseFeatureColumnForTests,\n                                 fc.CategoricalColumn):\n      \"\"\"Produces sparse IDs and sparse weights.\"\"\"\n\n      @property\n      def _is_v2_column(self):\n        return True\n\n      @property\n      def name(self):\n        return 'test_column'\n\n      @property\n      def parse_example_spec(self):\n        return {\n            self.name:\n                tf.io.VarLenFeature(tf.int32),\n            '{}_weights'.format(self.name):\n                tf.io.VarLenFeature(tf.float32),\n        }\n\n      @property\n      def num_buckets(self):\n        return 5\n\n      def transform_feature(self, transformation_cache, state_manager):\n        return (transformation_cache.get(self.name, state_manager),\n                transformation_cache.get('{}_weights'.format(self.name),\n                                         state_manager))\n\n      def get_sparse_tensors(self, transformation_cache, 
state_manager):\n        \"\"\"Populates both id_tensor and weight_tensor.\"\"\"\n        ids_and_weights = transformation_cache.get(self, state_manager)\n        return fc.CategoricalColumn.IdWeightPair(\n            id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])\n\n    t = _TestColumnWithWeights()\n    crossed = tf.feature_column.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)\n    with tf.Graph().as_default():\n      with self.assertRaisesRegexp(\n          ValueError,\n          'crossed_column does not support weight_tensor.*{}'.format(t.name)):\n        model = linear.LinearModel((crossed,))\n        model({\n            t.name:\n                tf.sparse.SparseTensor(\n                    indices=((0, 0), (1, 0), (1, 1)),\n                    values=[0, 1, 2],\n                    dense_shape=(2, 2)),\n            '{}_weights'.format(t.name):\n                tf.sparse.SparseTensor(\n                    indices=((0, 0), (1, 0), (1, 1)),\n                    values=[1., 10., 2.],\n                    dense_shape=(2, 2)),\n            'c':\n                tf.sparse.SparseTensor(\n                    indices=((0, 0), (1, 0), (1, 1)),\n                    values=['cA', 'cB', 'cC'],\n                    dense_shape=(2, 2)),\n        })\n\n\nclass LinearModelTest(tf.test.TestCase):\n\n  def test_raises_if_empty_feature_columns(self):\n    with self.assertRaisesRegexp(ValueError,\n                                 'feature_columns must not be empty'):\n      linear.LinearModel(feature_columns=[])\n\n  def test_should_be_feature_column(self):\n    with self.assertRaisesRegexp(ValueError, 'must be a FeatureColumn'):\n      linear.LinearModel(feature_columns='NotSupported')\n\n  def test_should_be_dense_or_categorical_column(self):\n\n    class NotSupportedColumn(BaseFeatureColumnForTests):\n\n      @property\n      def _is_v2_column(self):\n        return True\n\n      @property\n      def name(self):\n        return 'NotSupportedColumn'\n\n 
     def transform_feature(self, transformation_cache, state_manager):\n        pass\n\n      @property\n      def parse_example_spec(self):\n        pass\n\n    with self.assertRaisesRegexp(\n        ValueError, 'must be either a DenseColumn or CategoricalColumn'):\n      linear.LinearModel(feature_columns=[NotSupportedColumn()])\n\n  def test_does_not_support_dict_columns(self):\n    with self.assertRaisesRegexp(\n        ValueError, 'Expected feature_columns to be iterable, found dict.'):\n      linear.LinearModel(feature_columns={'a': tf.feature_column.numeric_column('a')})\n\n  def test_raises_if_duplicate_name(self):\n    with self.assertRaisesRegexp(\n        ValueError, 'Duplicate feature column name found for columns'):\n      linear.LinearModel(\n          feature_columns=[tf.feature_column.numeric_column('a'),\n                           tf.feature_column.numeric_column('a')])\n\n  def test_not_dict_input_features(self):\n    price = tf.feature_column.numeric_column('price')\n    with tf.Graph().as_default():\n      features = [[1.], [5.]]\n      model = linear.LinearModel([price])\n      with self.assertRaisesRegexp(ValueError, 'We expected a dictionary here'):\n        model(features)\n\n  def test_dense_bias(self):\n    price = tf.feature_column.numeric_column('price')\n    with tf.Graph().as_default():\n      features = {'price': [[1.], [5.]]}\n      model = linear.LinearModel([price])\n      predictions = model(features)\n      price_var, bias = model.variables\n      with _initialized_session() as sess:\n        self.assertAllClose([0.], self.evaluate(bias))\n        sess.run(price_var.assign([[10.]]))\n        sess.run(bias.assign([5.]))\n        self.assertAllClose([[15.], [55.]], self.evaluate(predictions))\n\n  def test_sparse_bias(self):\n    wire_cast = tf.feature_column.categorical_column_with_hash_bucket('wire_cast', 4)\n    with tf.Graph().as_default():\n      wire_tensor = tf.sparse.SparseTensor(\n          values=['omar', 'stringer', 
'marlo'],  # hashed to = [2, 0, 3]\n          indices=[[0, 0], [1, 0], [1, 1]],\n          dense_shape=[2, 2])\n      features = {'wire_cast': wire_tensor}\n      model = linear.LinearModel([wire_cast])\n      predictions = model(features)\n      wire_cast_var, bias = model.variables\n      with _initialized_session() as sess:\n        self.assertAllClose([0.], self.evaluate(bias))\n        self.assertAllClose([[0.], [0.], [0.], [0.]],\n                            self.evaluate(wire_cast_var))\n        sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))\n        sess.run(bias.assign([5.]))\n        self.assertAllClose([[1005.], [10015.]], self.evaluate(predictions))\n\n  def test_dense_and_sparse_bias(self):\n    wire_cast = tf.feature_column.categorical_column_with_hash_bucket('wire_cast', 4)\n    price = tf.feature_column.numeric_column('price')\n    with tf.Graph().as_default():\n      wire_tensor = tf.sparse.SparseTensor(\n          values=['omar', 'stringer', 'marlo'],  # hashed to = [2, 0, 3]\n          indices=[[0, 0], [1, 0], [1, 1]],\n          dense_shape=[2, 2])\n      features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}\n      model = linear.LinearModel([wire_cast, price])\n      predictions = model(features)\n      price_var, wire_cast_var, bias = model.variables\n      with _initialized_session() as sess:\n        sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))\n        sess.run(bias.assign([5.]))\n        sess.run(price_var.assign([[10.]]))\n        self.assertAllClose([[1015.], [10065.]], self.evaluate(predictions))\n\n  def test_dense_and_sparse_column(self):\n    \"\"\"When the column is both dense and sparse, uses sparse tensors.\"\"\"\n\n    class _DenseAndSparseColumn(BaseFeatureColumnForTests, tf.compat.v2.__internal__.feature_column.DenseColumn,\n                                fc.CategoricalColumn):\n\n      @property\n      def _is_v2_column(self):\n        return True\n\n      @property\n      def 
name(self):\n        return 'dense_and_sparse_column'\n\n      @property\n      def parse_example_spec(self):\n        return {self.name: tf.io.VarLenFeature(self.dtype)}\n\n      def transform_feature(self, transformation_cache, state_manager):\n        return transformation_cache.get(self.name, state_manager)\n\n      @property\n      def variable_shape(self):\n        raise ValueError('Should not use this method.')\n\n      def get_dense_tensor(self, transformation_cache, state_manager):\n        raise ValueError('Should not use this method.')\n\n      @property\n      def num_buckets(self):\n        return 4\n\n      def get_sparse_tensors(self, transformation_cache, state_manager):\n        sp_tensor = tf.sparse.SparseTensor(\n            indices=[[0, 0], [1, 0], [1, 1]],\n            values=[2, 0, 3],\n            dense_shape=[2, 2])\n        return fc.CategoricalColumn.IdWeightPair(sp_tensor, None)\n\n    dense_and_sparse_column = _DenseAndSparseColumn()\n    with tf.Graph().as_default():\n      sp_tensor = tf.sparse.SparseTensor(\n          values=['omar', 'stringer', 'marlo'],\n          indices=[[0, 0], [1, 0], [1, 1]],\n          dense_shape=[2, 2])\n      features = {dense_and_sparse_column.name: sp_tensor}\n      model = linear.LinearModel([dense_and_sparse_column])\n      predictions = model(features)\n      dense_and_sparse_column_var, bias = model.variables\n      with _initialized_session() as sess:\n        sess.run(\n            dense_and_sparse_column_var.assign([[10.], [100.], [1000.],\n                                                [10000.]]))\n        sess.run(bias.assign([5.]))\n        self.assertAllClose([[1005.], [10015.]], self.evaluate(predictions))\n\n  def test_dense_multi_output(self):\n    price = tf.feature_column.numeric_column('price')\n    with tf.Graph().as_default():\n      features = {'price': [[1.], [5.]]}\n      model = linear.LinearModel([price], units=3)\n      predictions = model(features)\n      price_var, bias = 
model.variables\n      with _initialized_session() as sess:\n        self.assertAllClose(np.zeros((3,)), self.evaluate(bias))\n        self.assertAllClose(np.zeros((1, 3)), self.evaluate(price_var))\n        sess.run(price_var.assign([[10., 100., 1000.]]))\n        sess.run(bias.assign([5., 6., 7.]))\n        self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],\n                            self.evaluate(predictions))\n\n  def test_sparse_multi_output(self):\n    wire_cast = tf.feature_column.categorical_column_with_hash_bucket('wire_cast', 4)\n    with tf.Graph().as_default():\n      wire_tensor = tf.sparse.SparseTensor(\n          values=['omar', 'stringer', 'marlo'],  # hashed to = [2, 0, 3]\n          indices=[[0, 0], [1, 0], [1, 1]],\n          dense_shape=[2, 2])\n      features = {'wire_cast': wire_tensor}\n      model = linear.LinearModel([wire_cast], units=3)\n      predictions = model(features)\n      wire_cast_var, bias = model.variables\n      with _initialized_session() as sess:\n        self.assertAllClose(np.zeros((3,)), self.evaluate(bias))\n        self.assertAllClose(np.zeros((4, 3)), self.evaluate(wire_cast_var))\n        sess.run(\n            wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.],\n                                  [1000., 1100., 1200.],\n                                  [10000., 11000., 12000.]]))\n        sess.run(bias.assign([5., 6., 7.]))\n        self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]],\n                            self.evaluate(predictions))\n\n  def test_dense_multi_dimension(self):\n    price = tf.feature_column.numeric_column('price', shape=2)\n    with tf.Graph().as_default():\n      features = {'price': [[1., 2.], [5., 6.]]}\n      model = linear.LinearModel([price])\n      predictions = model(features)\n      price_var, _ = model.variables\n      with _initialized_session() as sess:\n        self.assertAllClose([[0.], [0.]], self.evaluate(price_var))\n        
sess.run(price_var.assign([[10.], [100.]]))\n        self.assertAllClose([[210.], [650.]], self.evaluate(predictions))\n\n  def test_sparse_multi_rank(self):\n    wire_cast = tf.feature_column.categorical_column_with_hash_bucket('wire_cast', 4)\n    with tf.Graph().as_default():\n      wire_tensor = tf.compat.v1.sparse_placeholder(tf.string)\n      wire_value = tf.compat.v1.SparseTensorValue(\n          values=['omar', 'stringer', 'marlo', 'omar'],  # hashed = [2, 0, 3, 2]\n          indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],\n          dense_shape=[2, 2, 2])\n      features = {'wire_cast': wire_tensor}\n      model = linear.LinearModel([wire_cast])\n      predictions = model(features)\n      wire_cast_var, _ = model.variables\n      with _initialized_session() as sess:\n        self.assertAllClose(np.zeros((4, 1)), self.evaluate(wire_cast_var))\n        self.assertAllClose(\n            np.zeros((2, 1)),\n            predictions.eval(feed_dict={wire_tensor: wire_value}))\n        sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))\n        self.assertAllClose(\n            [[1010.], [11000.]],\n            predictions.eval(feed_dict={wire_tensor: wire_value}))\n\n  def test_sparse_combiner(self):\n    wire_cast = tf.feature_column.categorical_column_with_hash_bucket('wire_cast', 4)\n    with tf.Graph().as_default():\n      wire_tensor = tf.sparse.SparseTensor(\n          values=['omar', 'stringer', 'marlo'],  # hashed to = [2, 0, 3]\n          indices=[[0, 0], [1, 0], [1, 1]],\n          dense_shape=[2, 2])\n      features = {'wire_cast': wire_tensor}\n      model = linear.LinearModel([wire_cast], sparse_combiner='mean')\n      predictions = model(features)\n      wire_cast_var, bias = model.variables\n      with _initialized_session() as sess:\n        sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))\n        sess.run(bias.assign([5.]))\n        self.assertAllClose([[1005.], [5010.]], self.evaluate(predictions))\n\n  
def test_sparse_combiner_sqrtn(self):\n    wire_cast = tf.feature_column.categorical_column_with_hash_bucket('wire_cast', 4)\n    with tf.Graph().as_default():\n      wire_tensor = tf.sparse.SparseTensor(\n          values=['omar', 'stringer', 'marlo'],  # hashed to = [2, 0, 3]\n          indices=[[0, 0], [1, 0], [1, 1]],\n          dense_shape=[2, 2])\n      features = {'wire_cast': wire_tensor}\n      model = linear.LinearModel([wire_cast], sparse_combiner='sqrtn')\n      predictions = model(features)\n      wire_cast_var, bias = model.variables\n      with _initialized_session() as sess:\n        self.evaluate(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))\n        self.evaluate(bias.assign([5.]))\n        self.assertAllClose([[1005.], [7083.139]], self.evaluate(predictions))\n\n  def test_sparse_combiner_with_negative_weights(self):\n    wire_cast = tf.feature_column.categorical_column_with_hash_bucket('wire_cast', 4)\n    wire_cast_weights = tf.feature_column.weighted_categorical_column(wire_cast, 'weights')\n\n    with tf.Graph().as_default():\n      wire_tensor = tf.sparse.SparseTensor(\n          values=['omar', 'stringer', 'marlo'],  # hashed to = [2, 0, 3]\n          indices=[[0, 0], [1, 0], [1, 1]],\n          dense_shape=[2, 2])\n      features = {\n          'wire_cast': wire_tensor,\n          'weights': tf.compat.v2.constant([[1., 1., -1.0]])\n      }\n      model = linear.LinearModel([wire_cast_weights], sparse_combiner='sum')\n      predictions = model(features)\n      wire_cast_var, bias = model.variables\n      with _initialized_session() as sess:\n        sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))\n        sess.run(bias.assign([5.]))\n        self.assertAllClose([[1005.], [-9985.]], self.evaluate(predictions))\n\n  def test_dense_multi_dimension_multi_output(self):\n    price = tf.feature_column.numeric_column('price', shape=2)\n    with tf.Graph().as_default():\n      features = {'price': [[1., 2.], [5., 
6.]]}\n      model = linear.LinearModel([price], units=3)\n      predictions = model(features)\n      price_var, bias = model.variables\n      with _initialized_session() as sess:\n        self.assertAllClose(np.zeros((3,)), self.evaluate(bias))\n        self.assertAllClose(np.zeros((2, 3)), self.evaluate(price_var))\n        sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))\n        sess.run(bias.assign([2., 3., 4.]))\n        self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],\n                            self.evaluate(predictions))\n\n  def test_raises_if_shape_mismatch(self):\n    price = tf.feature_column.numeric_column('price', shape=2)\n    with tf.Graph().as_default():\n      features = {'price': [[1.], [5.]]}\n      with self.assertRaisesRegexp(\n          Exception,\n          r'Cannot reshape a tensor with 2 elements to shape \\[2,2\\]'):\n        model = linear.LinearModel([price])\n        model(features)\n\n  def test_dense_reshaping(self):\n    price = tf.feature_column.numeric_column('price', shape=[1, 2])\n    with tf.Graph().as_default():\n      features = {'price': [[[1., 2.]], [[5., 6.]]]}\n      model = linear.LinearModel([price])\n      predictions = model(features)\n      price_var, bias = model.variables\n      with _initialized_session() as sess:\n        self.assertAllClose([0.], self.evaluate(bias))\n        self.assertAllClose([[0.], [0.]], self.evaluate(price_var))\n        self.assertAllClose([[0.], [0.]], self.evaluate(predictions))\n        sess.run(price_var.assign([[10.], [100.]]))\n        self.assertAllClose([[210.], [650.]], self.evaluate(predictions))\n\n  def test_dense_multi_column(self):\n    price1 = tf.feature_column.numeric_column('price1', shape=2)\n    price2 = tf.feature_column.numeric_column('price2')\n    with tf.Graph().as_default():\n      features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}\n      model = linear.LinearModel([price1, price2])\n      predictions = 
model(features)\n      price1_var, price2_var, bias = model.variables\n      with _initialized_session() as sess:\n        self.assertAllClose([0.], self.evaluate(bias))\n        self.assertAllClose([[0.], [0.]], self.evaluate(price1_var))\n        self.assertAllClose([[0.]], self.evaluate(price2_var))\n        self.assertAllClose([[0.], [0.]], self.evaluate(predictions))\n        sess.run(price1_var.assign([[10.], [100.]]))\n        sess.run(price2_var.assign([[1000.]]))\n        sess.run(bias.assign([7.]))\n        self.assertAllClose([[3217.], [4657.]], self.evaluate(predictions))\n\n  def test_dense_trainable_default(self):\n    price = tf.feature_column.numeric_column('price')\n    with tf.Graph().as_default() as g:\n      features = {'price': [[1.], [5.]]}\n      model = linear.LinearModel([price])\n      model(features)\n      price_var, bias = model.variables\n      trainable_vars = g.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n      self.assertIn(bias, trainable_vars)\n      self.assertIn(price_var, trainable_vars)\n\n  def test_sparse_trainable_default(self):\n    wire_cast = tf.feature_column.categorical_column_with_hash_bucket('wire_cast', 4)\n    with tf.Graph().as_default() as g:\n      wire_tensor = tf.sparse.SparseTensor(\n          values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])\n      features = {'wire_cast': wire_tensor}\n      model = linear.LinearModel([wire_cast])\n      model(features)\n      trainable_vars = g.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n      wire_cast_var, bias = model.variables\n      self.assertIn(bias, trainable_vars)\n      self.assertIn(wire_cast_var, trainable_vars)\n\n  def test_dense_trainable_false(self):\n    price = tf.feature_column.numeric_column('price')\n    with tf.Graph().as_default() as g:\n      features = {'price': [[1.], [5.]]}\n      model = linear.LinearModel([price], trainable=False)\n      model(features)\n      trainable_vars = 
g.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n      self.assertEqual([], trainable_vars)\n\n  def test_sparse_trainable_false(self):\n    wire_cast = tf.feature_column.categorical_column_with_hash_bucket('wire_cast', 4)\n    with tf.Graph().as_default() as g:\n      wire_tensor = tf.sparse.SparseTensor(\n          values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])\n      features = {'wire_cast': wire_tensor}\n      model = linear.LinearModel([wire_cast], trainable=False)\n      model(features)\n      trainable_vars = g.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n      self.assertEqual([], trainable_vars)\n\n  def test_column_order(self):\n    price_a = tf.feature_column.numeric_column('price_a')\n    price_b = tf.feature_column.numeric_column('price_b')\n    wire_cast = tf.feature_column.categorical_column_with_hash_bucket('wire_cast', 4)\n    with tf.Graph().as_default():\n      features = {\n          'price_a': [[1.]],\n          'price_b': [[3.]],\n          'wire_cast':\n              tf.sparse.SparseTensor(\n                  values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])\n      }\n      model = linear.LinearModel([price_a, wire_cast, price_b])\n      model(features)\n\n      my_vars = model.variables\n      self.assertIn('price_a', my_vars[0].name)\n      self.assertIn('price_b', my_vars[1].name)\n      self.assertIn('wire_cast', my_vars[2].name)\n\n    with tf.Graph().as_default():\n      features = {\n          'price_a': [[1.]],\n          'price_b': [[3.]],\n          'wire_cast':\n              tf.sparse.SparseTensor(\n                  values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])\n      }\n      model = linear.LinearModel([wire_cast, price_b, price_a])\n      model(features)\n\n      my_vars = model.variables\n      self.assertIn('price_a', my_vars[0].name)\n      self.assertIn('price_b', my_vars[1].name)\n      self.assertIn('wire_cast', my_vars[2].name)\n\n  def test_variable_names(self):\n    
price1 = tf.feature_column.numeric_column('price1')\n    dense_feature = tf.feature_column.numeric_column('dense_feature')\n    dense_feature_bucketized = tf.feature_column.bucketized_column(\n        dense_feature, boundaries=[0.])\n    some_sparse_column = tf.feature_column.categorical_column_with_hash_bucket(\n        'sparse_feature', hash_bucket_size=5)\n    some_embedding_column = tf.feature_column.embedding_column(\n        some_sparse_column, dimension=10)\n    all_cols = [price1, dense_feature_bucketized, some_embedding_column]\n\n    with tf.Graph().as_default():\n      model = linear.LinearModel(all_cols)\n      features = {\n          'price1': [[3.], [4.]],\n          'dense_feature': [[-1.], [4.]],\n          'sparse_feature': [['a'], ['x']],\n      }\n      model(features)\n      for var in model.variables:\n        self.assertIsInstance(var, tf.Variable)\n      variable_names = [var.name for var in model.variables]\n      self.assertCountEqual([\n          'linear_model/dense_feature_bucketized/weights:0',\n          'linear_model/price1/weights:0',\n          'linear_model/sparse_feature_embedding/embedding_weights:0',\n          'linear_model/sparse_feature_embedding/weights:0',\n          'linear_model/bias_weights:0',\n      ], variable_names)\n\n  def test_fit_and_predict(self):\n    columns = [tf.feature_column.numeric_column('a')]\n\n    model = linear.LinearModel(columns)\n    model.compile(\n        optimizer=tf.compat.v1.train.RMSPropOptimizer(1e-3),\n        loss='binary_crossentropy',\n        metrics=['accuracy'])\n\n    x = {'a': np.random.random((10, 1))}\n    y = np.random.randint(0, 2, size=(10, 1))\n    model.fit(x, y, epochs=1, batch_size=5)\n    model.fit(x, y, epochs=1, batch_size=5)\n    model.evaluate(x, y, batch_size=5)\n    model.predict(x, batch_size=5)\n\n  def test_static_batch_size_mismatch(self):\n    price1 = tf.feature_column.numeric_column('price1')\n    price2 = tf.feature_column.numeric_column('price2')\n    with 
tf.Graph().as_default():\n      features = {\n          'price1': [[1.], [5.], [7.]],  # batchsize = 3\n          'price2': [[3.], [4.]]  # batchsize = 2\n      }\n    with self.assertRaisesRegexp(\n        ValueError,\n        r'Batch size \\(first dimension\\) of each feature must be same.'):  # pylint: disable=anomalous-backslash-in-string\n      model = linear.LinearModel([price1, price2])\n      model(features)\n\n  def test_subset_of_static_batch_size_mismatch(self):\n    price1 = tf.feature_column.numeric_column('price1')\n    price2 = tf.feature_column.numeric_column('price2')\n    price3 = tf.feature_column.numeric_column('price3')\n    with tf.Graph().as_default():\n      features = {\n          'price1': tf.compat.v1.placeholder(dtype=tf.int64),  # batchsize = 3\n          'price2': [[3.], [4.]],  # batchsize = 2\n          'price3': [[3.], [4.], [5.]]  # batchsize = 3\n      }\n      with self.assertRaisesRegexp(\n          ValueError,\n          r'Batch size \\(first dimension\\) of each feature must be same.'):  # pylint: disable=anomalous-backslash-in-string\n        model = linear.LinearModel([price1, price2, price3])\n        model(features)\n\n  def test_runtime_batch_size_mismatch(self):\n    price1 = tf.feature_column.numeric_column('price1')\n    price2 = tf.feature_column.numeric_column('price2')\n    with tf.Graph().as_default():\n      features = {\n          'price1': tf.compat.v1.placeholder(dtype=tf.int64),  # batchsize = 3\n          'price2': [[3.], [4.]]  # batchsize = 2\n      }\n      model = linear.LinearModel([price1, price2])\n      predictions = model(features)\n      with _initialized_session() as sess:\n        with self.assertRaisesRegexp(tf.errors.OpError,\n                                     'must have the same size and shape'):\n          sess.run(\n              predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})\n\n  def test_runtime_batch_size_matches(self):\n    price1 = 
tf.feature_column.numeric_column('price1')\n    price2 = tf.feature_column.numeric_column('price2')\n    with tf.Graph().as_default():\n      features = {\n          'price1': tf.compat.v1.placeholder(dtype=tf.int64),  # batchsize = 2\n          'price2': tf.compat.v1.placeholder(dtype=tf.int64),  # batchsize = 2\n      }\n      model = linear.LinearModel([price1, price2])\n      predictions = model(features)\n      with _initialized_session() as sess:\n        sess.run(\n            predictions,\n            feed_dict={\n                features['price1']: [[1.], [5.]],\n                features['price2']: [[1.], [5.]],\n            })\n\n  @test_util.run_deprecated_v1\n  def test_with_1d_sparse_tensor(self):\n    price = tf.feature_column.numeric_column('price')\n    price_buckets = tf.feature_column.bucketized_column(\n        price, boundaries=[\n            0.,\n            10.,\n            100.,\n        ])\n    body_style = tf.feature_column.categorical_column_with_vocabulary_list(\n        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])\n\n    # Provides 1-dim tensor and dense tensor.\n    features = {\n        'price':\n            tf.compat.v2.constant([\n                -1.,\n                12.,\n            ]),\n        'body-style':\n            tf.sparse.SparseTensor(\n                indices=((0,), (1,)),\n                values=('sedan', 'hardtop'),\n                dense_shape=(2,)),\n    }\n    self.assertEqual(1, features['price'].shape.ndims)\n    self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])\n\n    model = linear.LinearModel([price_buckets, body_style])\n    net = model(features)\n    with _initialized_session() as sess:\n      body_style_var, price_buckets_var, bias = model.variables\n\n      sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))\n      sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))\n      sess.run(bias.assign([5.]))\n\n      self.assertAllClose([[10 - 1000 
+ 5.], [1000 - 10 + 5.]],\n                          self.evaluate(net))\n\n  @test_util.run_deprecated_v1\n  def test_with_1d_unknown_shape_sparse_tensor(self):\n    price = tf.feature_column.numeric_column('price')\n    price_buckets = tf.feature_column.bucketized_column(\n        price, boundaries=[\n            0.,\n            10.,\n            100.,\n        ])\n    body_style = tf.feature_column.categorical_column_with_vocabulary_list(\n        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])\n    country = tf.feature_column.categorical_column_with_vocabulary_list(\n        'country', vocabulary_list=['US', 'JP', 'CA'])\n\n    # Provides 1-dim tensor and dense tensor.\n    features = {\n        'price': tf.compat.v1.placeholder(tf.float32),\n        'body-style': tf.compat.v1.sparse_placeholder(tf.string),\n        'country': tf.compat.v1.placeholder(tf.string),\n    }\n    self.assertIsNone(features['price'].shape.ndims)\n    self.assertIsNone(features['body-style'].get_shape().ndims)\n\n    price_data = np.array([-1., 12.])\n    body_style_data = tf.compat.v1.SparseTensorValue(\n        indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,))\n    country_data = np.array(['US', 'CA'])\n\n    model = linear.LinearModel([price_buckets, body_style, country])\n    net = model(features)\n    body_style_var, _, price_buckets_var, bias = model.variables\n    with _initialized_session() as sess:\n      sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))\n      sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))\n      sess.run(bias.assign([5.]))\n\n      self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],\n                          sess.run(\n                              net,\n                              feed_dict={\n                                  features['price']: price_data,\n                                  features['body-style']: body_style_data,\n                                  
features['country']: country_data\n                              }))\n\n  @test_util.run_deprecated_v1\n  def test_with_rank_0_feature(self):\n    price = tf.feature_column.numeric_column('price')\n    features = {\n        'price': tf.compat.v2.constant(0),\n    }\n    self.assertEqual(0, features['price'].shape.ndims)\n\n    # Static rank 0 should fail\n    with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):\n      model = linear.LinearModel([price])\n      model(features)\n\n    # Dynamic rank 0 should fail\n    features = {\n        'price': tf.compat.v1.placeholder(tf.float32),\n    }\n    model = linear.LinearModel([price])\n    net = model(features)\n    self.assertEqual(1, net.shape[1])\n    with _initialized_session() as sess:\n      with self.assertRaisesOpError('Feature .* cannot have rank 0'):\n        sess.run(net, feed_dict={features['price']: np.array(1)})\n\n  def test_multiple_linear_models(self):\n    price = tf.feature_column.numeric_column('price')\n    with tf.Graph().as_default():\n      features1 = {'price': [[1.], [5.]]}\n      features2 = {'price': [[2.], [10.]]}\n      model1 = linear.LinearModel([price])\n      model2 = linear.LinearModel([price])\n      predictions1 = model1(features1)\n      predictions2 = model2(features2)\n      price_var1, bias1 = model1.variables\n      price_var2, bias2 = model2.variables\n      with _initialized_session() as sess:\n        self.assertAllClose([0.], self.evaluate(bias1))\n        sess.run(price_var1.assign([[10.]]))\n        sess.run(bias1.assign([5.]))\n        self.assertAllClose([[15.], [55.]], self.evaluate(predictions1))\n        self.assertAllClose([0.], self.evaluate(bias2))\n        sess.run(price_var2.assign([[10.]]))\n        sess.run(bias2.assign([5.]))\n        self.assertAllClose([[25.], [105.]], self.evaluate(predictions2))\n\n\nclass VocabularyFileCategoricalColumnTest(tf.test.TestCase):\n\n  def setUp(self):\n    super(VocabularyFileCategoricalColumnTest, 
self).setUp()\n\n    # Contains strings, character names from 'The Wire': omar, stringer, marlo\n    self._wire_vocabulary_file_name = os.path.join(\n        flags.FLAGS['test_srcdir'].value,\n        'org_tensorflow_estimator/tensorflow_estimator',\n        'python/estimator/canned/testdata/wire_vocabulary.txt')\n\n    # self._wire_vocabulary_file_name = test.test_src_dir_path(\n    #     'python/estimator/canned/testdata/wire_vocabulary.txt')\n    self._wire_vocabulary_size = 3\n\n  # TODO(scottzhu): Reenable test once the issue for reading test file is fixed.\n  @test_util.run_deprecated_v1\n  def DISABLED_test_linear_model(self):\n    wire_column = tf.compat.v1.feature_column.categorical_column_with_vocabulary_file(\n        key='wire',\n        vocabulary_file=self._wire_vocabulary_file_name,\n        vocabulary_size=self._wire_vocabulary_size,\n        num_oov_buckets=1)\n    self.assertEqual(4, wire_column.num_buckets)\n    with tf.Graph().as_default():\n      model = linear.LinearModel((wire_column,))\n      predictions = model({\n          wire_column.name:\n              tf.compat.v1.SparseTensorValue(\n                  indices=((0, 0), (1, 0), (1, 1)),\n                  values=('marlo', 'skywalker', 'omar'),\n                  dense_shape=(2, 2))\n      })\n      wire_var, bias = model.variables\n\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.evaluate(tf.compat.v1.tables_initializer())\n\n      self.assertAllClose((0.,), self.evaluate(bias))\n      self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), self.evaluate(wire_var))\n      self.assertAllClose(((0.,), (0.,)), self.evaluate(predictions))\n      self.evaluate(wire_var.assign(((1.,), (2.,), (3.,), (4.,))))\n      # 'marlo' -> 2: wire_var[2] = 3\n      # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5\n      self.assertAllClose(((3.,), (5.,)), self.evaluate(predictions))\n\n\nclass VocabularyListCategoricalColumnTest(tf.test.TestCase):\n\n  
@test_util.run_deprecated_v1\n  def test_linear_model(self):\n    wire_column = tf.feature_column.categorical_column_with_vocabulary_list(\n        key='aaa',\n        vocabulary_list=('omar', 'stringer', 'marlo'),\n        num_oov_buckets=1)\n    self.assertEqual(4, wire_column.num_buckets)\n    with tf.Graph().as_default():\n      model = linear.LinearModel((wire_column,))\n      predictions = model({\n          wire_column.name:\n              tf.compat.v1.SparseTensorValue(\n                  indices=((0, 0), (1, 0), (1, 1)),\n                  values=('marlo', 'skywalker', 'omar'),\n                  dense_shape=(2, 2))\n      })\n      wire_var, bias = model.variables\n\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.evaluate(tf.compat.v1.tables_initializer())\n\n      self.assertAllClose((0.,), self.evaluate(bias))\n      self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), self.evaluate(wire_var))\n      self.assertAllClose(((0.,), (0.,)), self.evaluate(predictions))\n      self.evaluate(wire_var.assign(((1.,), (2.,), (3.,), (4.,))))\n      # 'marlo' -> 2: wire_var[2] = 3\n      # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5\n      self.assertAllClose(((3.,), (5.,)), self.evaluate(predictions))\n\n\nclass IdentityCategoricalColumnTest(tf.test.TestCase):\n\n  @test_util.run_deprecated_v1\n  def test_linear_model(self):\n    column = tf.feature_column.categorical_column_with_identity(key='aaa', num_buckets=3)\n    self.assertEqual(3, column.num_buckets)\n    with tf.Graph().as_default():\n      model = linear.LinearModel((column,))\n      predictions = model({\n          column.name:\n            tf.compat.v1.SparseTensorValue(\n                indices=((0, 0), (1, 0), (1, 1)),\n                values=(0, 2, 1),\n                dense_shape=(2, 2))\n      })\n      weight_var, bias = model.variables\n\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      
self.evaluate(tf.compat.v1.tables_initializer())\n\n      self.assertAllClose((0.,), self.evaluate(bias))\n      self.assertAllClose(((0.,), (0.,), (0.,)), self.evaluate(weight_var))\n      self.assertAllClose(((0.,), (0.,)), self.evaluate(predictions))\n      self.evaluate(weight_var.assign(((1.,), (2.,), (3.,))))\n      # weight_var[0] = 1\n      # weight_var[2] + weight_var[1] = 3+2 = 5\n      self.assertAllClose(((1.,), (5.,)), self.evaluate(predictions))\n\n\nclass IndicatorColumnTest(tf.test.TestCase):\n\n  @test_util.run_deprecated_v1\n  def test_linear_model(self):\n    animal = tf.feature_column.indicator_column(\n        tf.feature_column.categorical_column_with_identity('animal', num_buckets=4))\n    with tf.Graph().as_default():\n      features = {\n          'animal':\n              tf.sparse.SparseTensor(\n                  indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])\n      }\n\n      model = linear.LinearModel([animal])\n      predictions = model(features)\n      weight_var, _ = model.variables\n\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.evaluate(tf.compat.v1.tables_initializer())\n\n      # All should be zero-initialized.\n      self.assertAllClose([[0.], [0.], [0.], [0.]], self.evaluate(weight_var))\n      self.assertAllClose([[0.]], self.evaluate(predictions))\n      self.evaluate(weight_var.assign([[1.], [2.], [3.], [4.]]))\n      self.assertAllClose([[2. 
+ 3.]], self.evaluate(predictions))\n\n\nclass EmbeddingColumnTest(tf.test.TestCase, parameterized.TestCase):\n\n  @test_util.run_deprecated_v1\n  def test_linear_model(self):\n    # Inputs.\n    batch_size = 4\n    vocabulary_size = 3\n    sparse_input = tf.compat.v1.SparseTensorValue(\n        # example 0, ids [2]\n        # example 1, ids [0, 1]\n        # example 2, ids []\n        # example 3, ids [1]\n        indices=((0, 0), (1, 0), (1, 4), (3, 0)),\n        values=(2, 0, 1, 1),\n        dense_shape=(batch_size, 5))\n\n    # Embedding variable.\n    embedding_dimension = 2\n    embedding_shape = (vocabulary_size, embedding_dimension)\n    zeros_embedding_values = np.zeros(embedding_shape)\n\n    def _initializer(shape, dtype, partition_info=None):\n      self.assertAllEqual(embedding_shape, shape)\n      self.assertEqual(tf.float32, dtype)\n      self.assertIsNone(partition_info)\n      return zeros_embedding_values\n\n    # Build columns.\n    categorical_column = tf.feature_column.categorical_column_with_identity(\n        key='aaa', num_buckets=vocabulary_size)\n    embedding_column = tf.feature_column.embedding_column(\n        categorical_column,\n        dimension=embedding_dimension,\n        initializer=_initializer)\n\n    with tf.Graph().as_default():\n      model = linear.LinearModel((embedding_column,))\n      predictions = model({categorical_column.name: sparse_input})\n      expected_var_names = (\n          'linear_model/bias_weights:0',\n          'linear_model/aaa_embedding/weights:0',\n          'linear_model/aaa_embedding/embedding_weights:0',\n      )\n      self.assertCountEqual(\n          expected_var_names,\n          [v.name for v in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)])\n      trainable_vars = {\n          v.name: v\n          for v in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n      }\n      self.assertCountEqual(expected_var_names, trainable_vars.keys())\n      bias = 
trainable_vars['linear_model/bias_weights:0']\n      embedding_weights = trainable_vars[\n          'linear_model/aaa_embedding/embedding_weights:0']\n      linear_weights = trainable_vars['linear_model/aaa_embedding/weights:0']\n\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.evaluate(tf.compat.v1.tables_initializer())\n\n      # Predictions with all zero weights.\n      self.assertAllClose(np.zeros((1,)), self.evaluate(bias))\n      self.assertAllClose(zeros_embedding_values,\n                          self.evaluate(embedding_weights))\n      self.assertAllClose(\n          np.zeros((embedding_dimension, 1)), self.evaluate(linear_weights))\n      self.assertAllClose(np.zeros((batch_size, 1)), self.evaluate(predictions))\n\n      # Predictions with all non-zero weights.\n      self.evaluate(\n          embedding_weights.assign((\n              (1., 2.),  # id 0\n              (3., 5.),  # id 1\n              (7., 11.)  # id 2\n          )))\n      self.evaluate(linear_weights.assign(((4.,), (6.,))))\n      # example 0, ids [2], embedding[0] = [7, 11]\n      # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]\n      # example 2, ids [], embedding[2] = [0, 0]\n      # example 3, ids [1], embedding[3] = [3, 5]\n      # sum(embeddings * linear_weights)\n      # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]\n      self.assertAllClose(((94.,), (29.,), (0.,), (42.,)),\n                          self.evaluate(predictions))\n\n\nclass SharedEmbeddingColumnTest(tf.test.TestCase, parameterized.TestCase):\n\n  @test_util.run_deprecated_v1\n  def test_linear_model(self):\n    # Inputs.\n    batch_size = 2\n    vocabulary_size = 3\n    # -1 values are ignored.\n    input_a = np.array([\n        [2, -1, -1],  # example 0, ids [2]\n        [0, 1, -1]\n    ])  # example 1, ids [0, 1]\n    input_b = np.array([\n        [0, -1, -1],  # example 0, ids [0]\n        [-1, -1, -1]\n    ])  # example 1, ids []\n\n  
  # Embedding variable.\n    embedding_dimension = 2\n    embedding_shape = (vocabulary_size, embedding_dimension)\n    zeros_embedding_values = np.zeros(embedding_shape)\n\n    def _initializer(shape, dtype, partition_info=None):\n      self.assertAllEqual(embedding_shape, shape)\n      self.assertEqual(tf.float32, dtype)\n      self.assertIsNone(partition_info)\n      return zeros_embedding_values\n\n    # Build columns.\n    categorical_column_a = tf.feature_column.categorical_column_with_identity(\n        key='aaa', num_buckets=vocabulary_size)\n    categorical_column_b = tf.feature_column.categorical_column_with_identity(\n        key='bbb', num_buckets=vocabulary_size)\n    embedding_column_a, embedding_column_b = tf.compat.v2.feature_column.shared_embeddings(\n        [categorical_column_a, categorical_column_b],\n        dimension=embedding_dimension,\n        initializer=_initializer)\n\n    with tf.Graph().as_default():\n      model = linear.LinearModel((embedding_column_a, embedding_column_b))\n      predictions = model({\n          categorical_column_a.name: input_a,\n          categorical_column_b.name: input_b\n      })\n\n      # Linear weights do not follow the column name. 
But this is a rare use\n      # case, and fixing it would add too much complexity to the code.\n      expected_var_names = (\n          'linear_model/bias_weights:0',\n          'linear_model/aaa_shared_embedding/weights:0',\n          'aaa_bbb_shared_embedding:0',\n          'linear_model/bbb_shared_embedding/weights:0',\n      )\n      self.assertCountEqual(\n          expected_var_names,\n          [v.name for v in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)])\n      trainable_vars = {\n          v.name: v\n          for v in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n      }\n      self.assertCountEqual(expected_var_names, trainable_vars.keys())\n      bias = trainable_vars['linear_model/bias_weights:0']\n      embedding_weights = trainable_vars['aaa_bbb_shared_embedding:0']\n      linear_weights_a = trainable_vars[\n          'linear_model/aaa_shared_embedding/weights:0']\n      linear_weights_b = trainable_vars[\n          'linear_model/bbb_shared_embedding/weights:0']\n\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.evaluate(tf.compat.v1.tables_initializer())\n\n      # Predictions with all zero weights.\n      self.assertAllClose(np.zeros((1,)), self.evaluate(bias))\n      self.assertAllClose(zeros_embedding_values,\n                          self.evaluate(embedding_weights))\n      self.assertAllClose(\n          np.zeros((embedding_dimension, 1)), self.evaluate(linear_weights_a))\n      self.assertAllClose(\n          np.zeros((embedding_dimension, 1)), self.evaluate(linear_weights_b))\n      self.assertAllClose(np.zeros((batch_size, 1)), self.evaluate(predictions))\n\n      # Predictions with all non-zero weights.\n      self.evaluate(\n          embedding_weights.assign((\n              (1., 2.),  # id 0\n              (3., 5.),  # id 1\n              (7., 11.)  
# id 2\n          )))\n      self.evaluate(linear_weights_a.assign(((4.,), (6.,))))\n      # example 0, ids [2], embedding[0] = [7, 11]\n      # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]\n      # sum(embeddings * linear_weights)\n      # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29]\n      self.evaluate(linear_weights_b.assign(((3.,), (5.,))))\n      # example 0, ids [0], embedding[0] = [1, 2]\n      # example 1, ids [], embedding[1] = 0, 0]\n      # sum(embeddings * linear_weights)\n      # = [3*1 + 5*2, 3*0 +5*0] = [13, 0]\n      self.assertAllClose([[94. + 13.], [29.]], self.evaluate(predictions))\n\n\nclass WeightedCategoricalColumnTest(tf.test.TestCase):\n\n  @test_util.run_deprecated_v1\n  def test_linear_model(self):\n    column = tf.feature_column.weighted_categorical_column(\n        categorical_column=tf.feature_column.categorical_column_with_identity(\n            key='ids', num_buckets=3),\n        weight_feature_key='values')\n    with tf.Graph().as_default():\n      model = linear.LinearModel((column,))\n      predictions = model({\n          'ids':\n              tf.compat.v1.SparseTensorValue(\n                  indices=((0, 0), (1, 0), (1, 1)),\n                  values=(0, 2, 1),\n                  dense_shape=(2, 2)),\n          'values':\n              tf.compat.v1.SparseTensorValue(\n                  indices=((0, 0), (1, 0), (1, 1)),\n                  values=(.5, 1., .1),\n                  dense_shape=(2, 2))\n      })\n      weight_var, bias = model.variables\n\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.evaluate(tf.compat.v1.tables_initializer())\n\n      self.assertAllClose((0.,), self.evaluate(bias))\n      self.assertAllClose(((0.,), (0.,), (0.,)), self.evaluate(weight_var))\n      self.assertAllClose(((0.,), (0.,)), self.evaluate(predictions))\n      self.evaluate(weight_var.assign(((1.,), (2.,), (3.,))))\n      # weight_var[0] * weights[0, 0] = 1 * .5 = .5\n      # weight_var[2] 
* weights[1, 0] + weight_var[1] * weights[1, 1]\n      # = 3*1 + 2*.1 = 3+.2 = 3.2\n      self.assertAllClose(((.5,), (3.2,)), self.evaluate(predictions))\n\n  def test_linear_model_mismatched_shape(self):\n    column = tf.feature_column.weighted_categorical_column(\n        categorical_column=tf.feature_column.categorical_column_with_identity(\n            key='ids', num_buckets=3),\n        weight_feature_key='values')\n    with tf.Graph().as_default():\n      with self.assertRaisesRegexp(ValueError,\n                                   r'Dimensions.*are not compatible'):\n        model = linear.LinearModel((column,))\n        model({\n            'ids':\n                tf.compat.v1.SparseTensorValue(\n                    indices=((0, 0), (1, 0), (1, 1)),\n                    values=(0, 2, 1),\n                    dense_shape=(2, 2)),\n            'values':\n                tf.compat.v1.SparseTensorValue(\n                    indices=((0, 0), (0, 1), (1, 0), (1, 1)),\n                    values=(.5, 11., 1., .1),\n                    dense_shape=(2, 2))\n        })\n\n  def test_linear_model_mismatched_dense_values(self):\n    column = tf.feature_column.weighted_categorical_column(\n        categorical_column=tf.feature_column.categorical_column_with_identity(\n            key='ids', num_buckets=3),\n        weight_feature_key='values')\n    with tf.Graph().as_default():\n      model = linear.LinearModel((column,), sparse_combiner='mean')\n      predictions = model({\n          'ids':\n              tf.compat.v1.SparseTensorValue(\n                  indices=((0, 0), (1, 0), (1, 1)),\n                  values=(0, 2, 1),\n                  dense_shape=(2, 2)),\n          'values': ((.5,), (1.,))\n      })\n      # Disabling the constant folding optimizer here since it changes the\n      # error message differently on CPU and GPU.\n      config = tf.compat.v1.ConfigProto()\n      config.graph_options.rewrite_options.constant_folding = (\n          
rewriter_config_pb2.RewriterConfig.OFF)\n      with _initialized_session(config):\n        with self.assertRaisesRegexp(tf.errors.OpError, 'Incompatible shapes'):\n          self.evaluate(predictions)\n\n  def test_linear_model_mismatched_dense_shape(self):\n    column = tf.feature_column.weighted_categorical_column(\n        categorical_column=tf.feature_column.categorical_column_with_identity(\n            key='ids', num_buckets=3),\n        weight_feature_key='values')\n    with tf.Graph().as_default():\n      model = linear.LinearModel((column,))\n      predictions = model({\n          'ids':\n              tf.compat.v1.SparseTensorValue(\n                  indices=((0, 0), (1, 0), (1, 1)),\n                  values=(0, 2, 1),\n                  dense_shape=(2, 2)),\n          'values': ((.5,), (1.,), (.1,))\n      })\n      weight_var, bias = model.variables\n\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.evaluate(tf.compat.v1.tables_initializer())\n\n      self.assertAllClose((0.,), self.evaluate(bias))\n      self.assertAllClose(((0.,), (0.,), (0.,)), self.evaluate(weight_var))\n      self.assertAllClose(((0.,), (0.,)), self.evaluate(predictions))\n      self.evaluate(weight_var.assign(((1.,), (2.,), (3.,))))\n      # weight_var[0] * weights[0, 0] = 1 * .5 = .5\n      # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]\n      # = 3*1 + 2*.1 = 3+.2 = 3.2\n      self.assertAllClose(((.5,), (3.2,)), self.evaluate(predictions))\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass LinearModelLayerSerializationTest(tf.test.TestCase, parameterized.TestCase):\n\n  @parameterized.named_parameters(\n      ('trainable', 6, 'mean', True, 'trainable'),\n      ('not_trainable', 10, 'sum', False, 'frozen'))\n  def test_get_config(self, units, sparse_combiner, trainable, name):\n    cols = [tf.feature_column.numeric_column('a'),\n            tf.feature_column.categorical_column_with_identity(key='b', num_buckets=3)]\n    layer 
= linear._LinearModelLayer(\n        cols, units=units, sparse_combiner=sparse_combiner,\n        trainable=trainable, name=name)\n    config = layer.get_config()\n\n    self.assertEqual(config['name'], layer.name)\n    self.assertEqual(config['trainable'], trainable)\n    self.assertEqual(config['units'], units)\n    self.assertEqual(config['sparse_combiner'], sparse_combiner)\n    self.assertLen(config['feature_columns'], 2)\n    self.assertEqual(\n        config['feature_columns'][0]['class_name'], 'NumericColumn')\n    self.assertEqual(\n        config['feature_columns'][1]['class_name'], 'IdentityCategoricalColumn')\n\n  @parameterized.named_parameters(\n      ('trainable', 6, 'mean', True, 'trainable'),\n      ('not_trainable', 10, 'sum', False, 'frozen'))\n  def test_from_config(self, units, sparse_combiner, trainable, name):\n    cols = [tf.feature_column.numeric_column('a'),\n            tf.feature_column.categorical_column_with_vocabulary_list(\n                'b', vocabulary_list=('1', '2', '3')),\n            tf.feature_column.categorical_column_with_hash_bucket(\n                key='c', hash_bucket_size=3)]\n    orig_layer = linear._LinearModelLayer(\n        cols, units=units, sparse_combiner=sparse_combiner,\n        trainable=trainable, name=name)\n    config = orig_layer.get_config()\n\n    new_layer = linear._LinearModelLayer.from_config(config)\n\n    self.assertEqual(new_layer.name, orig_layer.name)\n    self.assertEqual(new_layer._units, units)\n    self.assertEqual(new_layer._sparse_combiner, sparse_combiner)\n    self.assertEqual(new_layer.trainable, trainable)\n    self.assertLen(new_layer._feature_columns, 3)\n    self.assertEqual(new_layer._feature_columns[0].name, 'a')\n    self.assertEqual(\n        new_layer._feature_columns[1].vocabulary_list, ('1', '2', '3'))\n    self.assertEqual(new_layer._feature_columns[2].num_buckets, 3)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/linear_optimizer/BUILD",
    "content": "# Placeholder: load py_library\nload(\"//tensorflow_estimator:estimator.bzl\", \"py_test\")\n\npackage(default_visibility = [\"//tensorflow_estimator:__subpackages__\"])\n\nlicenses([\"notice\"])\n\npy_test(\n    name = \"sdca_test\",\n    size = \"medium\",\n    srcs = [\"python/sdca_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    deps = [\n        \"//tensorflow_estimator/python/estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n        \"//tensorflow_estimator/python/estimator:linear\",\n    ],\n)\n\npy_library(\n    name = \"sdca_ops_py\",\n    srcs = [\n        \"__init__.py\",\n        \"python/utils/sdca_ops.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \":sharded_mutable_dense_hashtable_py\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"sdca_ops_test\",\n    size = \"medium\",\n    srcs = [\"python/utils/sdca_ops_test.py\"],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    tags = [\n        \"no_gpu\",\n        \"no_pip_gpu\",\n    ],\n    deps = [\n        \":sdca_ops_py\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_library(\n    name = \"sharded_mutable_dense_hashtable_py\",\n    srcs = [\"python/utils/sharded_mutable_dense_hashtable.py\"],\n    srcs_version = \"PY3\",\n    deps = [\n        
\"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n\npy_test(\n    name = \"sharded_mutable_dense_hashtable_test\",\n    size = \"small\",\n    srcs = [\"python/utils/sharded_mutable_dense_hashtable_test.py\"],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":sharded_mutable_dense_hashtable_py\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_keras_installed\",\n    ],\n)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/linear_optimizer/__init__.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Ops for training linear models.\n\n## This package provides optimizers to train linear models.\n\n\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom tensorflow.python.util.all_util import remove_undocumented\nremove_undocumented(__name__)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/linear_optimizer/doc/sdca.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"colab_type\": \"text\",\n        \"id\": \"DzJ8FQ_HsP7Q\"\n      },\n      \"source\": [\n        \"# Distributed SDCA\\n\",\n        \"\\n\",\n        \"$\\\\def\\\\a{\\\\alpha} \\\\def\\\\d{\\\\Delta\\\\a} \\\\def\\\\l{\\\\ell} \\\\def\\\\P{\\\\mathcal{P}}$\\n\",\n        \"We want to minimize on $K$ machines the following objective\\n\",\n        \"\\n\",\n        \"$$ P(w) = \\\\frac{1}{n}\\\\sum_{i=1}^n \\\\l_i(x_i^T w)+\\\\lambda g(w) $$\\n\",\n        \"\\n\",\n        \"By Fenchel duality, this is equivalent to maximizing its dual\\n\",\n        \"\\n\",\n        \"$$ D(\\\\a) = \\\\frac{1}{n} \\\\left(\\\\sum_{i=1}^n -\\\\l_i^\\\\star(-\\\\a_i)\\\\right) -\\\\lambda g^\\\\star\\\\left(\\\\tfrac{1}{\\\\lambda n} X\\\\a\\\\right) $$\\n\",\n        \"\\n\",\n        \"which can be done very efficiently on a single machine with SDCA [3].\\n\",\n        \"\\n\",\n        \"Here $f^\\\\star$ denotes the convex dual of a convex function $f$, $\\\\l_i$ is the loss for the example $i$, $n$ is the total number of examples and $\\\\lambda n$ is the L2 parameter.\\n\",\n        \"\\n\",\n        \"Following [1,2], we use a data partition $\\\\P_1,\\\\dots,\\\\P_K$ of $\\\\{1,2,\\\\dots,n\\\\}$ such that $\\\\P_k$ contains the examples on machine $k$.\\n\",\n        \"For an $n$-dimensional vector $h$, we denote by $h_{[k]}$ the $n$-dimensional vector restricted to the machine $k$: $(h_{[k]})_i = h_i$ if $i\\\\in\\\\P_k$ and $0$ otherwise.\\n\",\n        \"\\n\",\n        \"## CoCoA+ Local Solver\\n\",\n        \"\\n\",\n        \"The local subproblem on machine $k$ is [1, 2]\\n\",\n        \"\\n\",\n        \"$$ \\\\max_{\\\\d_{[k]}} \\\\mathcal{G}^{\\\\sigma}_k (\\\\d_{[k]}) $$\\n\",\n        \"\\n\",\n        \"with\\n\",\n        \"\\n\",\n        \"$$\\n\",\n        \"\\\\mathcal{G}^{\\\\sigma}_k (\\\\d_{[k]}) =\\n\",\n        \"-\\\\frac{1}{n} 
\\\\sum_{i\\\\in\\\\P_k}\\\\l_i^\\\\star(-\\\\a_i-(\\\\d_{[k]})_i) -\\\\frac{1}{n} w^T X\\n\",\n        \"\\\\d_{[k]}- \\\\frac{\\\\lambda}{2}\\\\sigma \\\\left\\\\| \\\\frac{1}{\\\\lambda n} X \\\\d_{[k]}\\n\",\n        \"\\\\right\\\\|^2 $$\\n\",\n        \"\\n\",\n        \"$\\\\sigma$ is a parameter that measures the difficulty of the data partition. CoCoA+ makes the choice $ \\\\sigma = K $\\n\",\n        \"\\n\",\n        \"This decision is motivated in [2] and shown to be more efficient than the previous CoCoA choice ($\\\\sigma = 1$).\\n\",\n        \"\\n\",\n        \"For one example, the problem is simply\\n\",\n        \"\\n\",\n        \"$$ \\\\max_{\\\\d} \\\\left\\\\{ D_i(\\\\d) = -\\\\l_i^\\\\star(-(\\\\a_i+\\\\d)) - \\\\bar{y}_i \\\\d - \\\\frac{A}{2} \\\\d^2 \\\\right\\\\} $$\\n\",\n        \"\\n\",\n        \"where we have defined $A=\\\\sigma X_i^2/(\\\\lambda n)$ and $ \\\\bar{y}_i = w^T X_i$\\n\",\n        \"\\n\",\n        \"To take into account example weights, it suffices to replace $1/n$ by $s_i/S$ where $s_i$ is the weight of the i-th example and $S=\\\\sum s_i$. For our problem, this will only change $A$ to $\\\\sigma X_i^2s_i/(\\\\lambda S)$.\\n\",\n        \"\\n\",\n        \"### Hinge Loss\\n\",\n        \"\\n\",\n        \"Hinge loss is given by $ \\\\l_i(u) = \\\\max(0,1-y u) $. Its convex dual is $\\\\l_i^\\\\star(-a) = -a y$ with the constraint $ a y\\\\in [0,1] $.\\n\",\n        \"\\n\",\n        \"The solution for the update is given explicitly in [3]. To derive the CoCoA+ formulation, we replace $\\\\lambda$ by $\\\\frac{\\\\lambda}{\\\\sigma}$. 
This gives\\n\",\n        \"\\n\",\n        \"$$ \\\\d = \\\\frac{y - \\\\bar{y}}{A} $$\\n\",\n        \"\\n\",\n        \"with the restriction that $y(\\\\a+\\\\d)\\\\in(0,1)$.\\n\",\n        \"\\n\",\n        \"### Smooth Hinge Loss\\n\",\n        \"\\n\",\n        \"Smooth hinge loss is given by\\n\",\n        \"\\n\",\n        \"$$ \\\\l_i(u) =\\n\",\n        \"\\\\begin{cases}\\n\",\n        \"0 \\\\:\\\\:\\\\: \\u0026 y_i u \\\\geq 1\\\\\\\\\\n\",\n        \"1-y_i u -\\\\gamma/2 \\\\:\\\\:\\\\:\\u0026 y_i u \\\\leq1-\\\\gamma \\\\\\\\\\n\",\n        \"\\\\frac{(1-y_i u)^2}{2\\\\gamma} \\u0026 \\\\text{otherwise}\\n\",\n        \"\\\\end{cases} $$\\n\",\n        \"\\n\",\n        \"The optimal $\\\\d$ is computed to be\\n\",\n        \"\\n\",\n        \"$$\\\\d = \\\\frac{y-\\\\bar{y}-\\\\gamma\\\\a}{A+\\\\gamma} $$\\n\",\n        \"\\n\",\n        \"with the restriction that $y(\\\\a+\\\\d)\\\\in(0,1)$. We see that we recover standard hinge update for $\\\\gamma = 0$. The details of the computation can be found in Appendix.\\n\",\n        \"\\n\",\n        \"### Squared Loss\\n\",\n        \"\\n\",\n        \"Squared loss is $ \\\\l_i(u) = \\\\frac{1}{2}(u-y)^2 $ with dual $ \\\\l_i^\\\\star(v) =\\\\frac{1}{2}v^2+y v$.\\n\",\n        \"\\n\",\n        \"The closed form solution for squared loss is given in [4]. By replacing again $\\\\lambda$ by $\\\\frac{\\\\lambda}{\\\\sigma}$ we obtain\\n\",\n        \"\\n\",\n        \"$$ \\\\d = -\\\\frac{\\\\a + w^T X_i - y}{1 + \\\\frac{\\\\sigma X_i^2}{2 \\\\lambda n}} $$\\n\",\n        \"\\n\",\n        \"### Logistic loss\\n\",\n        \"\\n\",\n        \"Logistic loss is $ \\\\l_i(u) = \\\\log (1+e^{-uy_i}) $ and its dual is\\n\",\n        \"\\n\",\n        \"$$ \\\\l_i^\\\\star(v) = -vy_i\\\\log(-vy_i) + (1+vy_i)\\n\",\n        \"\\\\log(1+vy_i) $$\\n\",\n        \"\\n\",\n        \"The label $y_i$ is $\\\\pm 1$ and the dual loss is only defined for $ -y_i v\\\\in (0,1) $. 
We then have the constraint\\n\",\n        \"\\n\",\n        \"$$  y_i (\\\\a+\\\\d) \\\\in (0,1) $$\\n\",\n        \"\\n\",\n        \"The problem of finding the maximum of $ D(\\\\d) $ can be reformulated as the problem of finding the unique zero of its derivative. Newton method works well for finding the zero of $ D'(\\\\d) $ but can be a bit unstable due to the constraint requiring $y_i(\\\\a+\\\\d)$ be in the range $(0,1)$ (more on this below).\\n\",\n        \"\\n\",\n        \"To avoid this problem, we make the following change of variable\\n\",\n        \"\\n\",\n        \"$$ y(\\\\a+\\\\d) = \\\\frac{1}{2}(1+\\\\tanh x) $$\\n\",\n        \"\\n\",\n        \"This enforces the constraint and is well suited because the objective derivative\\n\",\n        \"has the following simple form:\\n\",\n        \"\\n\",\n        \"$$ D' = H(x) = -2y x - \\\\bar{y} + A\\\\a -\\\\frac{A}{2y}(1+\\\\tanh x) $$\\n\",\n        \"\\n\",\n        \"with derivative\\n\",\n        \"\\n\",\n        \"$$ H'(x) = -2y - \\\\frac{A}{2y}(1-\\\\tanh^2 x) $$\\n\",\n        \"\\n\",\n        \"This function is always positive or always negative so that $H$ is strictly monotonic.\\n\",\n        \"\\n\",\n        \"We can start Newton algorithm at $x_0=0$ which corresponds to $ y(\\\\a+\\\\d) = 0.5 $. 
A Newton step is given by\\n\",\n        \"\\n\",\n        \"$$x_{k+1} = x_k - \\\\frac{H(x_k)}{H'(x_k)} $$\\n\",\n        \"\\n\",\n        \"The convergence is very fast with the modified function and 5 Newton steps should be largely enough.\\n\",\n        \"\\n\",\n        \"#### Proof of convergence\\n\",\n        \"\\n\",\n        \"The second derivative of $H$\\n\",\n        \"\\n\",\n        \"$$ H''(x) = \\\\frac{A}{y} \\\\tanh x (1-\\\\tanh^2 x) $$\\n\",\n        \"\\n\",\n        \"is bounded and quadratic convergence should be guaranteed if we are close enough to the solution (see proof [here](https://en.wikipedia.org/wiki/Newton%27s_method#Proof_of_quadratic_convergence_for_Newton.27s_iterative_method)).\\n\",\n        \"\\n\",\n        \"However, we can't really know if we are close to the zero. To prove the convergence in all cases, we can use the Kantorovich theorem (reviewed in [5]). The sufficient condition to have convergence is that we start at a point $ x_0 $ such that\\n\",\n        \"\\n\",\n        \"$$\\n\",\n        \"\\\\left|\\\\frac{4A H(x_0)}{H'(x_0)^2} \\\\right|\\\\leq 1\\n\",\n        \"$$\\n\",\n        \"\\n\",\n        \"If $ A$ is not small, the starting point $x_0 = 0$ doesn't satisfy this condition and we may solve the above inequality to find a starting point which does.\\n\",\n        \"\\n\",\n        \"However, in practice, convergence with $x_0 = 0$ always happens (tested for a sample of generic values for the parameters).\\n\",\n        \"\\n\",\n        \"### Poisson log loss\\n\",\n        \"\\n\",\n        \"Poisson log loss is defined as $ \\\\l(u) = e^u - uy $ for label $y \\\\geq 0.$ Its dual is\\n\",\n        \"\\n\",\n        \"$$ \\\\l^\\\\star(v) = (y+v) (\\\\log(y+v) - 1) $$\\n\",\n        \"\\n\",\n        \"and is only defined for $ y+v \\u003e 0 $. We then have the constraint\\n\",\n        \"\\n\",\n        \"$$  y \\u003e \\\\a+\\\\d. 
$$\\n\",\n        \"\\n\",\n        \"The dual is\\n\",\n        \"\\n\",\n        \"$$ D(\\\\d) = -(y-\\\\a-\\\\d) (\\\\log(y-\\\\a-\\\\d) - 1) - \\\\bar{y} \\\\d - \\\\frac{A}{2} \\\\d^2 $$\\n\",\n        \"\\n\",\n        \"and its derivative is,\\n\",\n        \"\\n\",\n        \"$$ D'(\\\\d) = \\\\log(y-\\\\a-\\\\d) - \\\\bar{y} - A\\\\d $$\\n\",\n        \"\\n\",\n        \"Similar to the logistic loss, we perform a change of variable to handle the constraint on $ \\\\d $\\n\",\n        \"\\n\",\n        \"$$ y - (\\\\a+\\\\d) = e^x $$\\n\",\n        \"\\n\",\n        \"After this change of variable, the goal is to find the zero of this function\\n\",\n        \"\\n\",\n        \"$$ H(x) = x - \\\\bar{y} -A(y-\\\\a-e^x) $$\\n\",\n        \"\\n\",\n        \"whose first derivative is\\n\",\n        \"\\n\",\n        \"$$ H'(x) = 1+Ae^x $$\\n\",\n        \"\\n\",\n        \"Since this function is always positive, $H$ is increasing and has a unique zero.\\n\",\n        \"\\n\",\n        \"We can start Newton algorithm at $\\\\d=0$ which corresponds to $ x =\\\\log(y-\\\\a)$. As before the Newton step is given by\\n\",\n        \"\\n\",\n        \"$$x_{k+1} = x_k - \\\\frac{H(x_k)}{H'(x_k)}. $$\\n\",\n        \"\\n\",\n        \"### References\\n\",\n        \"\\n\",\n        \"[1] C. Ma et al., [Adding vs. Averaging in Distributed Primal-Dual Optimization](https://arxiv.org/pdf/1502.03508.pdf), 2015.\\n\",\n        \"\\n\",\n        \"[2] C. Ma et al., [Distributed Optimization with Arbitrary Local Solvers](https://arxiv.org/pdf/1512.04039.pdf), 2015.\\n\",\n        \"\\n\",\n        \"[3] S. Shalev-Shwartz, T. Zhang, [Stochastic Dual Coordinate Ascent Methods for Regularized Loss Minimization](http://www.jmlr.org/papers/volume14/shalev-shwartz13a/shalev-shwartz13a.pdf), 2013.\\n\",\n        \"\\n\",\n        \"[4] S. Shalev-Shwartz, T. 
Zhang, [Accelerated Proximal Stochastic Dual Coordinate Ascent for Regularized Loss Minimization](https://arxiv.org/pdf/1309.2375.pdf), 2013.\\n\",\n        \"\\n\",\n        \"[5] A. Galantai, [The theory of Newton’s method](https://www.sciencedirect.com/science/article/pii/S0377042700004350), 2000.\\n\",\n        \"\\n\",\n        \"## Appendix\\n\",\n        \"\\n\",\n        \"#### Dual computation for smooth hinge loss\\n\",\n        \"\\n\",\n        \"We want to compute $\\\\l^\\\\star(v) = \\\\max_u [ uv-\\\\l(u) ] $ where $\\\\l$ is smooth hinge loss. We thus have to solve $v=\\\\l'(u)$. The derivative of smooth hinge loss is given by\\n\",\n        \"\\n\",\n        \"$$ \\\\l'(u) =\\n\",\n        \"\\\\begin{cases}\\n\",\n        \"0 \\\\:\\\\:\\\\: \\u0026 y_i u \\\\geq 1\\\\\\\\\\n\",\n        \"-y \\\\:\\\\:\\\\:\\u0026 y_i u \\\\leq1-\\\\gamma \\\\\\\\\\n\",\n        \"\\\\frac{u-y}{\\\\gamma} \\u0026 \\\\text{otherwise}\\n\",\n        \"\\\\end{cases} $$\\n\",\n        \"\\n\",\n        \"By solving for $v$, we find the dual of smooth hinge loss as\\n\",\n        \"\\n\",\n        \"$$ \\\\l^\\\\star(v) = yv + \\\\frac{\\\\gamma}{2}v^2 $$\\n\",\n        \"\\n\",\n        \"with the restriction $ yv \\\\in (0,1) $.\\n\",\n        \"\\n\",\n        \"Now, we can now minimize the dual objective with respect to $\\\\d$\\n\",\n        \"\\n\",\n        \"$$ D(\\\\a+\\\\d) = -\\\\l^\\\\star(-\\\\a-\\\\d)-\\\\bar{y}\\\\d-\\\\frac{A}{2} \\\\d^2 $$\\n\",\n        \"\\n\",\n        \"which gives the expected result\\n\",\n        \"\\n\",\n        \"$$\\\\d = \\\\frac{y-\\\\bar{y}-\\\\gamma\\\\a}{A+\\\\gamma} $$\\n\",\n        \"\\n\",\n        \"with the constraint $ y(\\\\a+\\\\d) \\\\in (0,1)$.\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"colab\": {\n      \"collapsed_sections\": [],\n      \"name\": \"SDCA.ipynb\",\n      \"provenance\": [\n        {\n          \"file_id\": \"1dYYnnjfGC6CIpfwy__EXL81Xnl26wiKM\",\n          \"timestamp\": 
1539909050558\n        }\n      ],\n      \"version\": \"0.3.2\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 0\n}\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/linear_optimizer/python/sdca_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for canned linear estimators with the SDCA optimizer.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.canned import linear\n\n\nclass SDCAClassifierTest(tf.test.TestCase):\n\n  def testRealValuedFeatures(self):\n    \"\"\"Tests LinearClassifier with LinearSDCA and real valued features.\"\"\"\n\n    def input_fn():\n      return {\n          'example_id': tf.constant(['1', '2']),\n          'maintenance_cost': tf.constant([[500.0], [200.0]]),\n          'sq_footage': tf.constant([[800.0], [600.0]]),\n          'weights': tf.constant([[1.0], [1.0]])\n      }, tf.constant([[0], [1]])\n\n    maintenance_cost = tf.feature_column.numeric_column('maintenance_cost')\n    sq_footage = tf.feature_column.numeric_column('sq_footage')\n    optimizer = linear.LinearSDCA(example_id_column='example_id')\n    classifier = linear.LinearClassifierV2(\n        feature_columns=[maintenance_cost, sq_footage],\n        weight_column='weights',\n        optimizer=optimizer)\n    classifier.train(input_fn=input_fn, steps=100)\n    loss = classifier.evaluate(input_fn=input_fn, steps=1)['loss']\n    
self.assertLess(loss, 0.2)\n\n  def testRealValuedFeatureWithHigherDimension(self):\n    \"\"\"Tests LinearSDCA with real valued features of higher dimension.\"\"\"\n\n    # input_fn is identical to the one in testRealValuedFeatures\n    # where 2 1-dimensional dense features have been replaced by 1 2-dimensional\n    # feature.\n    def input_fn():\n      return {\n          'example_id': tf.constant(['1', '2']),\n          'dense_feature': tf.constant([[500.0, 800.0], [200.0, 600.0]])\n      }, tf.constant([[0], [1]])\n\n    dense_feature = tf.feature_column.numeric_column('dense_feature', shape=2)\n    optimizer = linear.LinearSDCA(example_id_column='example_id')\n    classifier = linear.LinearClassifierV2(\n        feature_columns=[dense_feature], optimizer=optimizer)\n    classifier.train(input_fn=input_fn, steps=100)\n    loss = classifier.evaluate(input_fn=input_fn, steps=1)['loss']\n    self.assertLess(loss, 0.2)\n\n  def testBucketizedFeatures(self):\n    \"\"\"Tests LinearClassifier with LinearSDCA and bucketized features.\"\"\"\n\n    def input_fn():\n      return {\n          'example_id': tf.constant(['1', '2', '3']),\n          'price': tf.constant([[600.0], [1000.0], [400.0]]),\n          'sq_footage': tf.constant([[1000.0], [600.0], [700.0]]),\n          'weights': tf.constant([[1.0], [1.0], [1.0]])\n      }, tf.constant([[1], [0], [1]])\n\n    price_bucket = tf.feature_column.bucketized_column(\n        tf.feature_column.numeric_column('price'), boundaries=[500.0, 700.0])\n    sq_footage_bucket = tf.feature_column.bucketized_column(\n        tf.feature_column.numeric_column('sq_footage'), boundaries=[650.0])\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id', symmetric_l2_regularization=0.01)\n    classifier = linear.LinearClassifierV2(\n        feature_columns=[price_bucket, sq_footage_bucket],\n        weight_column='weights',\n        optimizer=optimizer)\n    classifier.train(input_fn=input_fn, steps=100)\n    loss = 
classifier.evaluate(input_fn=input_fn, steps=1)['loss']\n    self.assertLess(loss, 0.2)\n\n  def testSparseFeatures(self):\n    \"\"\"Tests LinearClassifier with LinearSDCA and sparse features.\"\"\"\n\n    def input_fn():\n      return {\n          'example_id':\n              tf.constant(['1', '2', '3']),\n          'country':\n              tf.sparse.SparseTensor(\n                  values=['IT', 'US', 'GB'],\n                  indices=[[0, 0], [1, 0], [2, 0]],\n                  dense_shape=[3, 5]),\n          'weights':\n              tf.constant([[1.0], [1.0], [1.0]])\n      }, tf.constant([[1], [0], [1]])\n\n    country = tf.feature_column.categorical_column_with_hash_bucket(\n        'country', hash_bucket_size=5)\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id', symmetric_l2_regularization=0.01)\n    classifier = linear.LinearClassifierV2(\n        feature_columns=[country], weight_column='weights', optimizer=optimizer)\n    classifier.train(input_fn=input_fn, steps=100)\n    loss = classifier.evaluate(input_fn=input_fn, steps=1)['loss']\n    self.assertLess(loss, 0.2)\n\n  def testWeightedSparseFeatures(self):\n    \"\"\"LinearClassifier with LinearSDCA and weighted sparse features.\"\"\"\n\n    def input_fn():\n      return {\n          'example_id':\n              tf.constant(['1', '2', '3']),\n          'price':\n              tf.sparse.SparseTensor(\n                  values=[2., 3., 1.],\n                  indices=[[0, 0], [1, 0], [2, 0]],\n                  dense_shape=[3, 5]),\n          'country':\n              tf.sparse.SparseTensor(\n                  values=['IT', 'US', 'GB'],\n                  indices=[[0, 0], [1, 0], [2, 0]],\n                  dense_shape=[3, 5])\n      }, tf.constant([[1], [0], [1]])\n\n    country = tf.feature_column.categorical_column_with_hash_bucket(\n        'country', hash_bucket_size=5)\n    country_weighted_by_price = (\n        tf.feature_column.weighted_categorical_column(country, 
'price'))\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id', symmetric_l2_regularization=0.01)\n    classifier = linear.LinearClassifierV2(\n        feature_columns=[country_weighted_by_price], optimizer=optimizer)\n    classifier.train(input_fn=input_fn, steps=100)\n    loss = classifier.evaluate(input_fn=input_fn, steps=1)['loss']\n    self.assertLess(loss, 0.2)\n\n  def testWeightedSparseFeaturesOOVWithNoOOVBuckets(self):\n    \"\"\"LinearClassifier with LinearSDCA with OOV features (-1 IDs).\"\"\"\n\n    def input_fn():\n      return {\n          'example_id':\n              tf.constant(['1', '2', '3']),\n          'price':\n              tf.sparse.SparseTensor(\n                  values=[2., 3., 1.],\n                  indices=[[0, 0], [1, 0], [2, 0]],\n                  dense_shape=[3, 5]),\n          'country':\n              tf.sparse.SparseTensor(\n                  # 'GB' is out of the vocabulary.\n                  values=['IT', 'US', 'GB'],\n                  indices=[[0, 0], [1, 0], [2, 0]],\n                  dense_shape=[3, 5])\n      }, tf.constant([[1], [0], [1]])\n\n    country = tf.feature_column.categorical_column_with_vocabulary_list(\n        'country', vocabulary_list=['US', 'CA', 'MK', 'IT', 'CN'])\n    country_weighted_by_price = (\n        tf.feature_column.weighted_categorical_column(country, 'price'))\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id', symmetric_l2_regularization=0.01)\n    classifier = linear.LinearClassifierV2(\n        feature_columns=[country_weighted_by_price], optimizer=optimizer)\n    classifier.train(input_fn=input_fn, steps=100)\n    loss = classifier.evaluate(input_fn=input_fn, steps=1)['loss']\n    self.assertLess(loss, 0.2)\n\n  def testCrossedFeatures(self):\n    \"\"\"Tests LinearClassifier with LinearSDCA and crossed features.\"\"\"\n\n    def input_fn():\n      return {\n          'example_id':\n              tf.constant(['1', '2', '3']),\n          
'language':\n              tf.sparse.SparseTensor(\n                  values=['english', 'italian', 'spanish'],\n                  indices=[[0, 0], [1, 0], [2, 0]],\n                  dense_shape=[3, 1]),\n          'country':\n              tf.sparse.SparseTensor(\n                  values=['US', 'IT', 'MX'],\n                  indices=[[0, 0], [1, 0], [2, 0]],\n                  dense_shape=[3, 1])\n      }, tf.constant([[0], [0], [1]])\n\n    country_language = tf.feature_column.crossed_column(['language', 'country'],\n                                                        hash_bucket_size=100)\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id', symmetric_l2_regularization=0.01)\n    classifier = linear.LinearClassifierV2(\n        feature_columns=[country_language], optimizer=optimizer)\n    classifier.train(input_fn=input_fn, steps=100)\n    loss = classifier.evaluate(input_fn=input_fn, steps=1)['loss']\n    self.assertLess(loss, 0.2)\n\n  def testMixedFeatures(self):\n    \"\"\"Tests LinearClassifier with LinearSDCA and a mix of features.\"\"\"\n\n    def input_fn():\n      return {\n          'example_id':\n              tf.constant(['1', '2', '3']),\n          'price':\n              tf.constant([[0.6], [0.8], [0.3]]),\n          'sq_footage':\n              tf.constant([[900.0], [700.0], [600.0]]),\n          'country':\n              tf.sparse.SparseTensor(\n                  values=['IT', 'US', 'GB'],\n                  indices=[[0, 0], [1, 3], [2, 1]],\n                  dense_shape=[3, 5]),\n          'weights':\n              tf.constant([[3.0], [1.0], [1.0]])\n      }, tf.constant([[1], [0], [1]])\n\n    price = tf.feature_column.numeric_column('price')\n    sq_footage_bucket = tf.feature_column.bucketized_column(\n        tf.feature_column.numeric_column('sq_footage'),\n        boundaries=[650.0, 800.0])\n    country = tf.feature_column.categorical_column_with_hash_bucket(\n        'country', hash_bucket_size=5)\n    
sq_footage_country = tf.feature_column.crossed_column(\n        [sq_footage_bucket, 'country'], hash_bucket_size=10)\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id', symmetric_l2_regularization=0.01)\n    classifier = linear.LinearClassifierV2(\n        feature_columns=[price, sq_footage_bucket, country, sq_footage_country],\n        weight_column='weights',\n        optimizer=optimizer)\n    classifier.train(input_fn=input_fn, steps=100)\n    loss = classifier.evaluate(input_fn=input_fn, steps=1)['loss']\n    self.assertLess(loss, 0.2)\n\n  def testPartitionedVariables(self):\n    \"\"\"Tests LinearClassifier with LinearSDCA with partitioned variables.\"\"\"\n\n    def input_fn():\n      return {\n          'example_id':\n              tf.constant(['1', '2', '3']),\n          'price':\n              tf.constant([[0.6], [0.8], [0.3]]),\n          'sq_footage':\n              tf.constant([[900.0], [700.0], [600.0]]),\n          'country':\n              tf.sparse.SparseTensor(\n                  values=['IT', 'US', 'GB'],\n                  indices=[[0, 0], [1, 3], [2, 1]],\n                  dense_shape=[3, 5]),\n          'weights':\n              tf.constant([[3.0], [1.0], [1.0]])\n      }, tf.constant([[1], [0], [1]])\n\n    price = tf.feature_column.numeric_column('price')\n    sq_footage_bucket = tf.feature_column.bucketized_column(\n        tf.feature_column.numeric_column('sq_footage'),\n        boundaries=[650.0, 800.0])\n    country = tf.feature_column.categorical_column_with_hash_bucket(\n        'country', hash_bucket_size=5)\n    sq_footage_country = tf.feature_column.crossed_column(\n        [sq_footage_bucket, 'country'], hash_bucket_size=10)\n\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id', symmetric_l2_regularization=0.01)\n\n    classifier = linear.LinearClassifier(\n        feature_columns=[price, sq_footage_bucket, country, sq_footage_country],\n        weight_column='weights',\n        
partitioner=tf.compat.v1.fixed_size_partitioner(num_shards=2, axis=0),\n        optimizer=optimizer)\n    classifier.train(input_fn=input_fn, steps=100)\n    loss = classifier.evaluate(input_fn=input_fn, steps=1)['loss']\n    self.assertLess(loss, 0.2)\n\n\nclass SDCARegressorTest(tf.test.TestCase):\n\n  def testRealValuedLinearFeatures(self):\n    \"\"\"Tests LinearRegressor with LinearSDCA and real valued features.\"\"\"\n    x = [[1.2, 2.0, -1.5], [-2.0, 3.0, -0.5], [1.0, -0.5, 4.0]]\n    weights = [[3.0], [-1.2], [0.5]]\n    y = np.dot(x, weights)\n\n    def input_fn():\n      return {\n          'example_id': tf.constant(['1', '2', '3']),\n          'x': tf.constant(x),\n          'weights': tf.constant([[10.0], [10.0], [10.0]])\n      }, tf.constant(y)\n\n    x_column = tf.feature_column.numeric_column('x', shape=3)\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id', symmetric_l2_regularization=0.1)\n    regressor = linear.LinearRegressorV2(\n        feature_columns=[x_column],\n        weight_column='weights',\n        optimizer=optimizer)\n    regressor.train(input_fn=input_fn, steps=20)\n    loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']\n    self.assertLess(loss, 0.01)\n    self.assertIn('linear/linear_model/x/weights',\n                  regressor.get_variable_names())\n    regressor_weights = regressor.get_variable_value(\n        'linear/linear_model/x/weights')\n    self.assertAllClose([w[0] for w in weights],\n                        regressor_weights.flatten(),\n                        rtol=0.1)\n\n  def testMixedFeaturesArbitraryWeights(self):\n    \"\"\"Tests LinearRegressor with LinearSDCA and a mix of features.\"\"\"\n\n    def input_fn():\n      return {\n          'example_id':\n              tf.constant(['1', '2', '3']),\n          'price':\n              tf.constant([0.6, 0.8, 0.3]),\n          'sq_footage':\n              tf.constant([[900.0], [700.0], [600.0]]),\n          'country':\n              
tf.sparse.SparseTensor(\n                  values=['IT', 'US', 'GB'],\n                  indices=[[0, 0], [1, 3], [2, 1]],\n                  dense_shape=[3, 5]),\n          'weights':\n              tf.constant([[3.0], [5.0], [7.0]])\n      }, tf.constant([[1.55], [-1.25], [-3.0]])\n\n    price = tf.feature_column.numeric_column('price')\n    sq_footage_bucket = tf.feature_column.bucketized_column(\n        tf.feature_column.numeric_column('sq_footage'),\n        boundaries=[650.0, 800.0])\n    country = tf.feature_column.categorical_column_with_hash_bucket(\n        'country', hash_bucket_size=5)\n    sq_footage_country = tf.feature_column.crossed_column(\n        [sq_footage_bucket, 'country'], hash_bucket_size=10)\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id', symmetric_l2_regularization=0.1)\n    regressor = linear.LinearRegressorV2(\n        feature_columns=[price, sq_footage_bucket, country, sq_footage_country],\n        weight_column='weights',\n        optimizer=optimizer)\n    regressor.train(input_fn=input_fn, steps=20)\n    loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']\n    self.assertLess(loss, 0.05)\n\n  def testPartitionedVariables(self):\n    \"\"\"Tests LinearRegressor with LinearSDCA with partitioned variables.\"\"\"\n\n    def input_fn():\n      return {\n          'example_id':\n              tf.constant(['1', '2', '3']),\n          'price':\n              tf.constant([0.6, 0.8, 0.3]),\n          'sq_footage':\n              tf.constant([[900.0], [700.0], [600.0]]),\n          'country':\n              tf.sparse.SparseTensor(\n                  values=['IT', 'US', 'GB'],\n                  indices=[[0, 0], [1, 3], [2, 1]],\n                  dense_shape=[3, 5]),\n          'weights':\n              tf.constant([[3.0], [5.0], [7.0]])\n      }, tf.constant([[1.55], [-1.25], [-3.0]])\n\n    price = tf.feature_column.numeric_column('price')\n    sq_footage_bucket = tf.feature_column.bucketized_column(\n  
      tf.feature_column.numeric_column('sq_footage'),\n        boundaries=[650.0, 800.0])\n    country = tf.feature_column.categorical_column_with_hash_bucket(\n        'country', hash_bucket_size=5)\n    sq_footage_country = tf.feature_column.crossed_column(\n        [sq_footage_bucket, 'country'], hash_bucket_size=10)\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id', symmetric_l2_regularization=0.1)\n\n    regressor = linear.LinearRegressor(\n        feature_columns=[price, sq_footage_bucket, country, sq_footage_country],\n        weight_column='weights',\n        partitioner=tf.compat.v1.fixed_size_partitioner(num_shards=2, axis=0),\n        optimizer=optimizer)\n    regressor.train(input_fn=input_fn, steps=20)\n    loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']\n    self.assertLess(loss, 0.05)\n\n  def testSparseFeaturesWithL1Reg(self):\n    \"\"\"Tests LinearRegressor with LinearSDCA and sparse features.\"\"\"\n\n    def input_fn():\n      return {\n          'example_id':\n              tf.constant(['1', '2', '3']),\n          'price':\n              tf.constant([[0.4], [0.6], [0.3]]),\n          'country':\n              tf.sparse.SparseTensor(\n                  values=['IT', 'US', 'GB'],\n                  indices=[[0, 0], [1, 3], [2, 1]],\n                  dense_shape=[3, 5]),\n          'weights':\n              tf.constant([[10.0], [10.0], [10.0]])\n      }, tf.constant([[1.4], [-0.8], [2.6]])\n\n    price = tf.feature_column.numeric_column('price')\n    country = tf.feature_column.categorical_column_with_hash_bucket(\n        'country', hash_bucket_size=5)\n    # Regressor with no L1 regularization.\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id', symmetric_l2_regularization=0.1)\n    regressor = linear.LinearRegressorV2(\n        feature_columns=[price, country],\n        weight_column='weights',\n        optimizer=optimizer)\n    regressor.train(input_fn=input_fn, steps=20)\n    
no_l1_reg_loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']\n    variable_names = regressor.get_variable_names()\n    self.assertIn('linear/linear_model/price/weights', variable_names)\n    self.assertIn('linear/linear_model/country/weights', variable_names)\n    no_l1_reg_weights = {\n        'linear/linear_model/price/weights':\n            regressor.get_variable_value('linear/linear_model/price/weights'),\n        'linear/linear_model/country/weights':\n            regressor.get_variable_value('linear/linear_model/country/weights'),\n    }\n\n    # Regressor with L1 regularization.\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id',\n        symmetric_l1_regularization=1.0,\n        symmetric_l2_regularization=0.1)\n    regressor = linear.LinearRegressorV2(\n        feature_columns=[price, country],\n        weight_column='weights',\n        optimizer=optimizer)\n    regressor.train(input_fn=input_fn, steps=20)\n    l1_reg_loss = regressor.evaluate(input_fn=input_fn, steps=1)['loss']\n    l1_reg_weights = {\n        'linear/linear_model/price/weights':\n            regressor.get_variable_value('linear/linear_model/price/weights'),\n        'linear/linear_model/country/weights':\n            regressor.get_variable_value('linear/linear_model/country/weights'),\n    }\n\n    # Unregularized loss is lower when there is no L1 regularization.\n    self.assertLess(no_l1_reg_loss, l1_reg_loss)\n    self.assertLess(no_l1_reg_loss, 0.05)\n\n    # But weights returned by the regressor with L1 regularization have smaller\n    # L1 norm.\n    l1_reg_weights_norm, no_l1_reg_weights_norm = 0.0, 0.0\n    for var_name in sorted(l1_reg_weights):\n      l1_reg_weights_norm += sum(\n          np.absolute(l1_reg_weights[var_name].flatten()))\n      no_l1_reg_weights_norm += sum(\n          np.absolute(no_l1_reg_weights[var_name].flatten()))\n      print('Var name: %s, value: %s' %\n            (var_name, no_l1_reg_weights[var_name].flatten()))\n  
  self.assertLess(l1_reg_weights_norm, no_l1_reg_weights_norm)\n\n  def testBiasOnly(self):\n    \"\"\"Tests LinearRegressor with LinearSDCA and validates bias weight.\"\"\"\n\n    def input_fn():\n      \"\"\"Testing the bias weight when it's the only feature present.\n\n      All of the instances in this input only have the bias feature, and a\n      1/4 of the labels are positive. This means that the expected weight for\n      the bias should be close to the average prediction, i.e 0.25.\n      Returns:\n        Training data for the test.\n      \"\"\"\n      num_examples = 40\n      return {\n          'example_id': tf.constant([str(x + 1) for x in range(num_examples)]),\n          # place_holder is an empty column which is always 0 (absent), because\n          # LinearClassifier requires at least one column.\n          'place_holder': tf.constant([[0.0]] * num_examples),\n      }, tf.constant([1 if i % 4 == 0 else 0 for i in range(num_examples)])\n\n    place_holder = tf.feature_column.numeric_column('place_holder')\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id', symmetric_l2_regularization=0.1)\n    regressor = linear.LinearRegressorV2(\n        feature_columns=[place_holder], optimizer=optimizer)\n    regressor.train(input_fn=input_fn, steps=100)\n    self.assertNear(\n        regressor.get_variable_value('linear/linear_model/bias_weights')[0],\n        0.25,\n        err=0.1)\n\n  def testBiasAndOtherColumns(self):\n    \"\"\"Tests LinearRegressor with LinearSDCA and validates bias weight.\"\"\"\n\n    def input_fn():\n      \"\"\"Testing the bias weight when there are other features present.\n\n      1/2 of the instances in this input have feature 'a', the rest have\n      feature 'b', and we expect the bias to be added to each instance as well.\n      0.4 of all instances that have feature 'a' are positive, and 0.2 of all\n      instances that have feature 'b' are positive. 
The labels in the dataset\n      are ordered to appear shuffled since SDCA expects shuffled data, and\n      converges faster with this pseudo-random ordering.\n      If the bias was not regularized we would expect the weights to be:\n      bias: 0.3\n      a: 0.1\n      b: -0.1\n      Bu with bias regularization the optimal values are:\n      bias: 0.2\n      a: 0.2\n      b: 0.0\n      Returns:\n        The test dataset.\n      \"\"\"\n      num_examples = 200\n      half = int(num_examples / 2)\n      return {\n          'example_id': tf.constant([str(x + 1) for x in range(num_examples)]),\n          'a': tf.constant([[1]] * int(half) + [[0]] * int(half)),\n          'b': tf.constant([[0]] * int(half) + [[1]] * int(half)),\n      }, tf.constant([[x]\n                      for x in [1, 0, 0, 1, 1, 0, 0, 0, 1, 0] * int(half / 10) +\n                      [0, 1, 0, 0, 0, 0, 0, 0, 1, 0] * int(half / 10)])\n\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id', symmetric_l2_regularization=0.1)\n    regressor = linear.LinearRegressorV2(\n        feature_columns=[\n            tf.feature_column.numeric_column('a'),\n            tf.feature_column.numeric_column('b')\n        ],\n        optimizer=optimizer)\n\n    regressor.train(input_fn=input_fn, steps=200)\n\n    variable_names = regressor.get_variable_names()\n    self.assertIn('linear/linear_model/bias_weights', variable_names)\n    self.assertIn('linear/linear_model/a/weights', variable_names)\n    self.assertIn('linear/linear_model/b/weights', variable_names)\n    # TODO(b/29339026): Change the expected results to expect a centered bias.\n    self.assertNear(\n        regressor.get_variable_value('linear/linear_model/bias_weights')[0],\n        0.2,\n        err=0.05)\n    self.assertNear(\n        regressor.get_variable_value('linear/linear_model/a/weights')[0],\n        0.2,\n        err=0.05)\n    self.assertNear(\n        
regressor.get_variable_value('linear/linear_model/b/weights')[0],\n        0.0,\n        err=0.05)\n\n  def testBiasAndOtherColumnsFabricatedCentered(self):\n    \"\"\"Tests LinearRegressor with LinearSDCA and validates bias weight.\"\"\"\n\n    def input_fn():\n      \"\"\"Testing the bias weight when there are other features present.\n\n      1/2 of the instances in this input have feature 'a', the rest have\n      feature 'b', and we expect the bias to be added to each instance as well.\n      0.1 of all instances that have feature 'a' have a label of 1, and 0.1 of\n      all instances that have feature 'b' have a label of -1.\n      We can expect the weights to be:\n      bias: 0.0\n      a: 0.1\n      b: -0.1\n      Returns:\n        The test dataset.\n      \"\"\"\n      num_examples = 200\n      half = int(num_examples / 2)\n      return {\n          'example_id': tf.constant([str(x + 1) for x in range(num_examples)]),\n          'a': tf.constant([[1]] * int(half) + [[0]] * int(half)),\n          'b': tf.constant([[0]] * int(half) + [[1]] * int(half)),\n      }, tf.constant([[1 if x % 10 == 0 else 0] for x in range(half)] +\n                     [[-1 if x % 10 == 0 else 0] for x in range(half)])\n\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id', symmetric_l2_regularization=0.1)\n    regressor = linear.LinearRegressorV2(\n        feature_columns=[\n            tf.feature_column.numeric_column('a'),\n            tf.feature_column.numeric_column('b')\n        ],\n        optimizer=optimizer)\n\n    regressor.train(input_fn=input_fn, steps=100)\n\n    variable_names = regressor.get_variable_names()\n    self.assertIn('linear/linear_model/bias_weights', variable_names)\n    self.assertIn('linear/linear_model/a/weights', variable_names)\n    self.assertIn('linear/linear_model/b/weights', variable_names)\n    self.assertNear(\n        regressor.get_variable_value('linear/linear_model/bias_weights')[0],\n        0.0,\n        err=0.05)\n  
  self.assertNear(\n        regressor.get_variable_value('linear/linear_model/a/weights')[0],\n        0.1,\n        err=0.05)\n    self.assertNear(\n        regressor.get_variable_value('linear/linear_model/b/weights')[0],\n        -0.1,\n        err=0.05)\n\n  def testUnknownBatchSize(self):\n    \"\"\"Tests LinearRegressor with LinearSDCA and unknown batch size.\"\"\"\n\n    def input_fn():\n      # Similar to testBiasOnly but use placeholder_with_default in order to\n      # let the static batch size unspecified.\n      return {\n          'example_id':\n              tf.compat.v1.placeholder_with_default(\n                  tf.constant(['0', '1']), shape=[None]),\n          # always_zero is an empty column which is always 0 (absent), because\n          # LinearClassifier requires at least one column.\n          'always_zero':\n              tf.compat.v1.placeholder_with_default(\n                  tf.constant([[0.0]] * 2), shape=[None, 1]),\n      }, tf.compat.v1.placeholder_with_default(\n          tf.constant([0.0, 1.0]), shape=[None])\n\n    always_zero = tf.feature_column.numeric_column('always_zero')\n    optimizer = linear.LinearSDCA(\n        example_id_column='example_id',\n        symmetric_l2_regularization=0.1,\n        num_table_shards=3)\n    regressor = linear.LinearRegressorV2(\n        feature_columns=[always_zero], optimizer=optimizer)\n    regressor.train(input_fn=input_fn, steps=100)\n    self.assertNear(\n        regressor.get_variable_value('linear/linear_model/bias_weights')[0],\n        0.5,\n        err=0.1)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/linear_optimizer/python/utils/sdca_ops.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Proximal stochastic dual coordinate ascent optimizer for linear models.\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\n\nfrom six.moves import range\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.framework.ops import internal_convert_to_tensor\nfrom tensorflow.python.framework.ops import name_scope\nfrom tensorflow.python.ops import gen_sdca_ops\nfrom tensorflow.python.ops import variables as var_ops\nfrom tensorflow.python.ops.nn import log_poisson_loss\nfrom tensorflow.python.ops.nn import sigmoid_cross_entropy_with_logits\nfrom tensorflow_estimator.python.estimator.canned.linear_optimizer.python.utils.sharded_mutable_dense_hashtable import _ShardedMutableDenseHashTable\n\n\nclass _SparseFeatureColumn(object):\n  \"\"\"Represents a sparse feature column.\n\n  This is meant to be a more efficient representation than tf.SparseFeature for\n  the purpose of SDCA optimization.\n  Contains three tensors representing a sparse feature column, they are\n  example indices (`int64`), feature indices (`int64`), and feature\n  values (`float`).\n  Feature weights are optional, and are treated as `1.0f` if missing.\n\n  For 
example, consider a batch of 4 examples, which contains the following\n  features in a particular `_SparseFeatureColumn`:\n\n  * Example 0: feature 5, value 1\n  * Example 1: feature 6, value 1 and feature 10, value 0.5\n  * Example 2: no features\n  * Example 3: two copies of feature 2, value 1\n\n  This _SparseFeatureColumn will be represented as follows:\n\n  ```\n   <0, 5,  1>\n   <1, 6,  1>\n   <1, 10, 0.5>\n   <3, 2,  1>\n   <3, 2,  1>\n  ```\n\n  For a batch of 2 examples below:\n\n  * Example 0: feature 5\n  * Example 1: feature 6\n\n  is represented by `_SparseFeatureColumn` as:\n\n  ```\n   <0, 5,  1>\n   <1, 6,  1>\n\n  ```\n\n  @@__init__\n  @@example_indices\n  @@feature_indices\n  @@feature_values\n  \"\"\"\n\n  def __init__(self, example_indices, feature_indices, feature_values):\n    \"\"\"Creates a `_SparseFeatureColumn` representation.\n\n    Args:\n      example_indices: A 1-D int64 tensor of shape `[N]`. Also, accepts python\n        lists, or numpy arrays.\n      feature_indices: A 1-D int64 tensor of shape `[N]`. Also, accepts python\n        lists, or numpy arrays.\n      feature_values: An optional 1-D tensor float tensor of shape `[N]`. 
Also,\n        accepts python lists, or numpy arrays.\n\n    Returns:\n      A `_SparseFeatureColumn`\n    \"\"\"\n    with name_scope(None, 'SparseFeatureColumn',\n                    [example_indices, feature_indices]):\n      self._example_indices = internal_convert_to_tensor(\n          example_indices, name='example_indices', dtype=tf.dtypes.int64)\n      self._feature_indices = internal_convert_to_tensor(\n          feature_indices, name='feature_indices', dtype=tf.dtypes.int64)\n    self._feature_values = None\n    if feature_values is not None:\n      with name_scope(None, 'SparseFeatureColumn', [feature_values]):\n        self._feature_values = internal_convert_to_tensor(\n            feature_values, name='feature_values', dtype=tf.dtypes.float32)\n\n  @property\n  def example_indices(self):\n    \"\"\"The example indices represented as a dense tensor.\n\n    Returns:\n      A 1-D Tensor of int64 with shape `[N]`.\n    \"\"\"\n    return self._example_indices\n\n  @property\n  def feature_indices(self):\n    \"\"\"The feature indices represented as a dense tensor.\n\n    Returns:\n      A 1-D Tensor of int64 with shape `[N]`.\n    \"\"\"\n    return self._feature_indices\n\n  @property\n  def feature_values(self):\n    \"\"\"The feature values represented as a dense tensor.\n\n    Returns:\n      May return None, or a 1-D Tensor of float32 with shape `[N]`.\n    \"\"\"\n    return self._feature_values\n\n\nclass _SDCAModel(object):\n  \"\"\"Stochastic dual coordinate ascent solver for linear models.\n\n    Loss functions supported:\n\n     * Binary logistic loss\n     * Squared loss\n     * Hinge loss\n     * Smooth hinge loss\n     * Poisson log loss\n\n    ### Usage\n\n    ```python\n    # Create a solver with the desired parameters.\n    lr = _SDCAModel(examples, variables, options)\n    min_op = lr.minimize()\n    opt_op = lr.update_weights(min_op)\n\n    predictions = lr.predictions(examples)\n    # Primal loss + L1 loss + L2 loss.\n    
regularized_loss = lr.regularized_loss(examples)\n    # Primal loss only\n    unregularized_loss = lr.unregularized_loss(examples)\n\n    examples: {\n      sparse_features: list of SparseFeatureColumn.\n      dense_features: list of dense tensors of type float32.\n      example_labels: a tensor of type float32 and shape [Num examples]\n      example_weights: a tensor of type float32 and shape [Num examples]\n      example_ids: a tensor of type string and shape [Num examples]\n    }\n    variables: {\n      sparse_features_weights: list of tensors of shape [vocab size]\n      dense_features_weights: list of tensors of shape [dense_feature_dimension]\n    }\n    options: {\n      symmetric_l1_regularization: 0.0\n      symmetric_l2_regularization: 1.0\n      loss_type: \"logistic_loss\"\n      num_loss_partitions: 1 (Optional, with default value of 1. Number of\n      partitions of the global loss function, 1 means single machine solver,\n      and >1 when we have more than one optimizer working concurrently.)\n      num_table_shards: 1 (Optional, with default value of 1. 
Number of shards\n      of the internal state table, typically set to match the number of\n      parameter servers for large data sets.\n    }\n    ```\n\n    In the training program you will just have to run the returned Op from\n    minimize().\n\n    ```python\n    # Execute opt_op and train for num_steps.\n    for _ in range(num_steps):\n      opt_op.run()\n\n    # You can also check for convergence by calling\n    lr.approximate_duality_gap()\n    ```\n  \"\"\"\n\n  def __init__(self, examples, variables, options):\n    \"\"\"Create a new sdca optimizer.\"\"\"\n\n    if not examples or not variables or not options:\n      raise ValueError('examples, variables and options must all be specified.')\n\n    supported_losses = ('logistic_loss', 'squared_loss', 'hinge_loss',\n                        'smooth_hinge_loss', 'poisson_loss')\n    if options['loss_type'] not in supported_losses:\n      raise ValueError('Unsupported loss_type: ', options['loss_type'])\n\n    self._assert_specified([\n        'example_labels', 'example_weights', 'example_ids', 'sparse_features',\n        'dense_features'\n    ], examples)\n    self._assert_list(['sparse_features', 'dense_features'], examples)\n\n    self._assert_specified(\n        ['sparse_features_weights', 'dense_features_weights'], variables)\n    self._assert_list(['sparse_features_weights', 'dense_features_weights'],\n                      variables)\n\n    self._assert_specified([\n        'loss_type', 'symmetric_l2_regularization',\n        'symmetric_l1_regularization'\n    ], options)\n\n    if options['symmetric_l2_regularization'] <= 0.0:\n      raise ValueError('symmetric_l2_regularization should be positive.')\n    if options['symmetric_l2_regularization'] <= 1.0:\n      tf.compat.v1.logging.warn(\n          'symmetric_l2_regularization for SDCA should typically be '\n          'larger than for online optimization methods. 
Recommended '\n          'value is of the order of the average L2 norm of the '\n          'training examples.')\n    if options['symmetric_l1_regularization'] < 0.0:\n      raise ValueError('symmetric_l1_regularization should be non-negative.')\n\n    self._examples = examples\n    self._variables = variables\n    self._options = options\n    self._create_slots()\n    self._hashtable = _ShardedMutableDenseHashTable(\n        key_dtype=tf.dtypes.int64,\n        value_dtype=tf.dtypes.float32,\n        num_shards=self._num_table_shards(),\n        default_value=[0.0, 0.0, 0.0, 0.0],\n        # SdcaFprint never returns 0 or 1 for the low64 bits, so this a safe\n        # empty_key (that will never collide with actual payloads).\n        empty_key=[0, 0],\n        deleted_key=[1, 1])\n\n    tf.compat.v1.summary.scalar('approximate_duality_gap',\n                                self.approximate_duality_gap())\n    tf.compat.v1.summary.scalar('examples_seen', self._hashtable.size())\n\n  def _symmetric_l1_regularization(self):\n    return self._options['symmetric_l1_regularization']\n\n  def _symmetric_l2_regularization(self):\n    return self._options['symmetric_l2_regularization']\n\n  def _num_loss_partitions(self):\n    # Number of partitions of the global objective.\n    return self._options.get('num_loss_partitions', 1)\n\n  def _adaptive(self):\n    # Perform adaptive sampling.\n    return self._options.get('adaptive', True)\n\n  def _num_table_shards(self):\n    # Number of hash table shards.\n    # Return 1 if not specified or if the value is 'None'\n    num_shards = self._options.get('num_table_shards')\n    return 1 if num_shards is None else num_shards\n\n  def _create_slots(self):\n    \"\"\"Make unshrunk internal variables (slots).\"\"\"\n    # Unshrunk variables have the updates before applying L1 regularization.\n    # Each unshrunk slot variable is either a `Variable` or list of\n    # `Variable`, depending on the value of its corresponding primary 
variable.\n    # We avoid using `PartitionedVariable` for the unshrunk slots since we do\n    # not need any of the extra information.\n    self._slots = collections.defaultdict(list)\n    for name in ['sparse_features_weights', 'dense_features_weights']:\n      for var in self._variables[name]:\n        # Our primary variable may be either a PartitionedVariable, or a list\n        # of Variables (each representing a partition).\n        if (isinstance(var, var_ops.PartitionedVariable) or\n            isinstance(var, list)):\n          var_list = []\n          for v in var:\n            with ops.colocate_with(v):\n              slot_var = tf.Variable(\n                  initial_value=tf.compat.v1.zeros_like(\n                      tf.cond(\n                          tf.compat.v1.is_variable_initialized(v),\n                          v.read_value,\n                          lambda: v.initial_value),\n                      tf.dtypes.float32),\n                  name=v.op.name + '_unshrunk')\n              var_list.append(slot_var)\n          self._slots['unshrunk_' + name].append(var_list)\n        else:\n          with tf.compat.v1.device(var.device):\n            self._slots['unshrunk_' + name].append(\n                tf.Variable(\n                    tf.compat.v1.zeros_like(\n                        tf.cond(\n                            tf.compat.v1.is_variable_initialized(var),\n                            var.read_value,\n                            lambda: var.initial_value),\n                        tf.dtypes.float32),\n                    name=var.op.name + '_unshrunk'))\n\n  def _assert_specified(self, items, check_in):\n    for x in items:\n      if check_in[x] is None:\n        raise ValueError(x + ' must be specified.')\n\n  def _assert_list(self, items, check_in):\n    for x in items:\n      if not isinstance(check_in[x], list):\n        raise ValueError(x + ' must be a list.')\n\n  def _var_to_list(self, var):\n    \"\"\"Wraps var in a list 
if it is not a list or PartitionedVariable.\"\"\"\n    if not isinstance(var, (list, var_ops.PartitionedVariable)):\n      var = [var]\n    return var\n\n  def _l1_loss(self):\n    \"\"\"Computes the (un-normalized) l1 loss of the model.\"\"\"\n    with name_scope('sdca/l1_loss'):\n      sums = []\n      for name in ['sparse_features_weights', 'dense_features_weights']:\n        for var in self._variables[name]:\n          for v in self._var_to_list(var):\n            weights = internal_convert_to_tensor(v)\n            with tf.compat.v1.device(weights.device):\n              sums.append(\n                  tf.math.reduce_sum(\n                      tf.math.abs(tf.cast(weights, tf.dtypes.float64))))\n      # SDCA L1 regularization cost is: l1 * sum(|weights|)\n      return self._symmetric_l1_regularization() * tf.math.add_n(sums)\n\n  def _l2_loss(self):\n    \"\"\"Computes the (un-normalized) l2 loss of the model.\"\"\"\n    with name_scope('sdca/l2_loss'):\n      sums = []\n      for name in ['sparse_features_weights', 'dense_features_weights']:\n        for var in self._variables[name]:\n          for v in self._var_to_list(var):\n            weights = internal_convert_to_tensor(v)\n            with tf.compat.v1.device(weights.device):\n              sums.append(\n                  tf.math.reduce_sum(\n                      tf.math.square(tf.cast(weights, tf.dtypes.float64))))\n      # SDCA L2 regularization cost is: l2 * sum(weights^2) / 2\n      return self._symmetric_l2_regularization() * tf.math.add_n(sums) / 2.0\n\n  def _convert_n_to_tensor(self, input_list, as_ref=False):\n    \"\"\"Converts input list to a set of tensors.\"\"\"\n    # input_list can be a list of Variables (that are implicitly partitioned),\n    # in which case the underlying logic in internal_convert_to_tensor will not\n    # concatenate the partitions together.  
This method takes care of the\n    # concatenating (we only allow partitioning on the first axis).\n    output_list = []\n    for x in input_list:\n      tensor_to_convert = x\n      if isinstance(x, list) or isinstance(x, var_ops.PartitionedVariable):\n        # We only allow for partitioning on the first axis.\n        tensor_to_convert = tf.concat(x, axis=0)\n      output_list.append(\n          internal_convert_to_tensor(tensor_to_convert, as_ref=as_ref))\n    return output_list\n\n  def _get_first_dimension_size_statically(self, w, num_partitions):\n    \"\"\"Compute the static size of the first dimension for a sharded variable.\"\"\"\n    dim_0_size = w[0].get_shape()[0]\n    for p in range(1, num_partitions):\n      dim_0_size += w[p].get_shape()[0]\n    return dim_0_size\n\n  def _linear_predictions(self, examples):\n    \"\"\"Returns predictions of the form w*x.\n\n    Args:\n      examples: Examples to compute predictions on.\n    \"\"\"\n    with name_scope('sdca/prediction'):\n      batch_size = tf.compat.v1.shape(examples['example_ids'])[0]\n\n      predictions = tf.zeros([batch_size])\n      sparse_variables = self._convert_n_to_tensor(\n          self._variables['sparse_features_weights'])\n      for sfc, sv in zip(examples['sparse_features'], sparse_variables):\n        unpadded_dot_product = tf.math.segment_sum(\n            tf.math.multiply(\n                tf.compat.v1.gather(sv, sfc.feature_indices),\n                sfc.feature_values), sfc.example_indices)\n        predictions += tf.compat.v1.pad(\n            unpadded_dot_product,\n            [[0, batch_size - tf.compat.v1.shape(unpadded_dot_product)[0]]])\n\n      dense_features = self._convert_n_to_tensor(examples['dense_features'])\n      dense_variables = self._convert_n_to_tensor(\n          self._variables['dense_features_weights'])\n      for i in range(len(dense_variables)):\n        predictions += tf.compat.v1.squeeze(\n            tf.linalg.matmul(dense_features[i],\n              
               tf.compat.v1.expand_dims(dense_variables[i], -1)))\n\n    return predictions\n\n  def predictions(self, examples):\n    \"\"\"Add operations to compute predictions by the model.\n\n    If logistic_loss is being used, predicted probabilities are returned.\n    If poisson_loss is being used, predictions are exponentiated.\n    Otherwise, (raw) linear predictions (w*x) are returned.\n\n    Args:\n      examples: Examples to compute predictions on.\n\n    Returns:\n      An Operation that computes the predictions for examples.\n\n    Raises:\n      ValueError: if examples are not well defined.\n    \"\"\"\n    self._assert_specified(\n        ['example_weights', 'sparse_features', 'dense_features'], examples)\n    self._assert_list(['sparse_features', 'dense_features'], examples)\n\n    result = self._linear_predictions(examples)\n    if self._options['loss_type'] == 'logistic_loss':\n      # Convert logits to probability for logistic loss predictions.\n      with name_scope('sdca/logistic_prediction'):\n        result = tf.math.sigmoid(result)\n    elif self._options['loss_type'] == 'poisson_loss':\n      # Exponentiate the prediction for poisson loss predictions.\n      with name_scope('sdca/poisson_prediction'):\n        result = tf.math.exp(result)\n    return result\n\n  def _get_partitioned_update_ops(self, v_num, num_partitions_by_var,\n                                  p_assignments_by_var, gather_ids_by_var,\n                                  weights, full_update, p_assignments,\n                                  num_partitions):\n    \"\"\"Get updates for partitioned variables.\"\"\"\n    num_partitions = num_partitions_by_var[v_num]\n    p_assignments = p_assignments_by_var[v_num]\n    gather_ids = gather_ids_by_var[v_num]\n    updates = tf.dynamic_partition(full_update, p_assignments, num_partitions)\n    update_ops = []\n    for p in range(num_partitions):\n      with ops.colocate_with(weights[p]):\n        result = 
tf.compat.v1.scatter_add(weights[p], gather_ids[p], updates[p])\n      update_ops.append(result)\n    return update_ops\n\n  def minimize(self, global_step=None, name=None):\n    \"\"\"Add operations to train a linear model by minimizing the loss function.\n\n    Args:\n      global_step: Optional `Variable` to increment by one after the variables\n        have been updated.\n      name: Optional name for the returned operation.\n\n    Returns:\n      An Operation that updates the variables passed in the constructor.\n    \"\"\"\n    # Technically, the op depends on a lot more than the variables,\n    # but we'll keep the list short.\n    with name_scope(name, 'sdca/minimize'):\n      sparse_example_indices = []\n      sparse_feature_indices = []\n      sparse_features_values = []\n      for sf in self._examples['sparse_features']:\n        sparse_example_indices.append(sf.example_indices)\n        sparse_feature_indices.append(sf.feature_indices)\n        # If feature values are missing, sdca assumes a value of 1.0f.\n        if sf.feature_values is not None:\n          sparse_features_values.append(sf.feature_values)\n\n      example_ids_hashed = tf.compat.v1.train.sdca_fprint(\n          internal_convert_to_tensor(self._examples['example_ids']))\n      example_state_data = self._hashtable.lookup(example_ids_hashed)\n      # Solver returns example_state_update, new delta sparse_feature_weights\n      # and delta dense_feature_weights.\n\n      sparse_weights = []\n      sparse_indices = []\n      # If we have partitioned variables, keep a few dictionaries of Tensors\n      # around that we need for the assign_add after the op call to\n      # gen_sdca_ops.sdca_optimizer().  
These are keyed because we may have a\n      # mix of partitioned and un-partitioned variables.\n      num_partitions_by_var = {}\n      p_assignments_by_var = {}\n      gather_ids_by_var = {}\n      for v_num, (w, i) in enumerate(\n          zip(self._slots['unshrunk_sparse_features_weights'],\n              sparse_feature_indices)):\n        # Append the sparse_indices (in full-variable space).\n        sparse_idx = tf.cast(\n            tf.unique(tf.cast(i, tf.dtypes.int32))[0], tf.dtypes.int64)\n        sparse_indices.append(sparse_idx)\n        if isinstance(w, list) or isinstance(w, var_ops.PartitionedVariable):\n          num_partitions = len(w)\n          flat_ids = tf.reshape(sparse_idx, [-1])\n          # We use div partitioning, which is easiest to support downstream.\n          # Compute num_total_ids as the sum of dim-0 of w, then assign\n          # to partitions based on a constant number of ids per partition.\n          # Optimize if we already know the full shape statically.\n          dim_0_size = self._get_first_dimension_size_statically(\n              w, num_partitions)\n\n          if tf.compat.dimension_value(dim_0_size):\n            num_total_ids = tf.constant(\n                tf.compat.dimension_value(dim_0_size), flat_ids.dtype)\n          else:\n            dim_0_sizes = []\n            for p in range(num_partitions):\n              if tf.compat.dimension_value(w[p].shape[0]) is not None:\n                dim_0_sizes.append(tf.compat.dimension_value(w[p].shape[0]))\n              else:\n                with ops.colocate_with(w[p]):\n                  dim_0_sizes.append(tf.compat.v1.shape(w[p])[0])\n            num_total_ids = tf.math.reduce_sum(\n                tf.cast(tf.stack(dim_0_sizes), flat_ids.dtype))\n          ids_per_partition = num_total_ids // num_partitions\n          extras = num_total_ids % num_partitions\n\n          p_assignments = tf.math.maximum(flat_ids // (ids_per_partition + 1),\n                                   
       (flat_ids - extras) //\n                                          ids_per_partition)\n\n          # Emulate a conditional using a boolean indicator tensor\n          new_ids = tf.where(p_assignments < extras,\n                             flat_ids % (ids_per_partition + 1),\n                             (flat_ids - extras) % ids_per_partition)\n\n          # Cast partition assignments to int32 for use in dynamic_partition.\n          # There really should not be more than 2^32 partitions.\n          p_assignments = tf.cast(p_assignments, tf.dtypes.int32)\n          # Partition list of ids based on assignments into num_partitions\n          # separate lists.\n          gather_ids = tf.dynamic_partition(new_ids, p_assignments,\n                                            num_partitions)\n          # Add these into the dictionaries for use in the later update.\n          num_partitions_by_var[v_num] = num_partitions\n          p_assignments_by_var[v_num] = p_assignments\n          gather_ids_by_var[v_num] = gather_ids\n\n          # Gather the weights from each partition.\n          partition_gathered_weights = []\n          for p in range(num_partitions):\n            with ops.colocate_with(w[p]):\n              partition_gathered_weights.append(\n                  tf.compat.v1.gather(w[p], gather_ids[p]))\n\n          # Stitch the weights back together in the same order they were before\n          # we dynamic_partitioned them.\n          condition_indices = tf.dynamic_partition(\n              tf.range(tf.compat.v1.shape(new_ids)[0]), p_assignments,\n              num_partitions)\n          batch_gathered_weights = tf.dynamic_stitch(\n              condition_indices, partition_gathered_weights)\n        else:\n          w_as_tensor = internal_convert_to_tensor(w)\n          with tf.compat.v1.device(w_as_tensor.device):\n            batch_gathered_weights = tf.compat.v1.gather(\n                w_as_tensor, sparse_idx)\n        
sparse_weights.append(batch_gathered_weights)\n\n      if tf.compat.forward_compatible(year=2018, month=10, day=30):\n        esu, sfw, dfw = gen_sdca_ops.sdca_optimizer_v2(\n            sparse_example_indices,\n            sparse_feature_indices,\n            sparse_features_values,\n            self._convert_n_to_tensor(self._examples['dense_features']),\n            internal_convert_to_tensor(self._examples['example_weights']),\n            internal_convert_to_tensor(self._examples['example_labels']),\n            sparse_indices,\n            sparse_weights,\n            self._convert_n_to_tensor(\n                self._slots['unshrunk_dense_features_weights']),\n            example_state_data,\n            loss_type=self._options['loss_type'],\n            l1=self._symmetric_l1_regularization(),\n            l2=self._symmetric_l2_regularization(),\n            num_loss_partitions=self._num_loss_partitions(),\n            num_inner_iterations=1,\n            adaptive=self._adaptive())\n      else:\n        esu, sfw, dfw = tf.compat.v1.train.sdca_optimizer(\n            sparse_example_indices,\n            sparse_feature_indices,\n            sparse_features_values,\n            self._convert_n_to_tensor(self._examples['dense_features']),\n            internal_convert_to_tensor(self._examples['example_weights']),\n            internal_convert_to_tensor(self._examples['example_labels']),\n            sparse_indices,\n            sparse_weights,\n            self._convert_n_to_tensor(\n                self._slots['unshrunk_dense_features_weights']),\n            example_state_data,\n            loss_type=self._options['loss_type'],\n            l1=self._symmetric_l1_regularization(),\n            l2=self._symmetric_l2_regularization(),\n            num_loss_partitions=self._num_loss_partitions(),\n            num_inner_iterations=1,\n            adaptative=self._adaptive())\n\n      with tf.control_dependencies([esu]):\n        update_ops = 
[self._hashtable.insert(example_ids_hashed, esu)]\n        # Update the weights before the proximal step.\n        for v_num, (w, i, u) in enumerate(\n            zip(self._slots['unshrunk_sparse_features_weights'], sparse_indices,\n                sfw)):\n          if (isinstance(w, var_ops.PartitionedVariable) or\n              isinstance(w, list)):\n            update_ops += self._get_partitioned_update_ops(\n                v_num, num_partitions_by_var, p_assignments_by_var,\n                gather_ids_by_var, w, u, p_assignments, num_partitions)\n          else:\n            update_ops.append(tf.compat.v1.scatter_add(w, i, u))\n        for w, u in zip(self._slots['unshrunk_dense_features_weights'], dfw):\n          if (isinstance(w, var_ops.PartitionedVariable) or\n              isinstance(w, list)):\n            split_updates = tf.split(\n                u, num_or_size_splits=[v.shape.as_list()[0] for v in w])\n            for v, split_update in zip(w, split_updates):\n              update_ops.append(tf.compat.v1.assign_add(v, split_update))\n          else:\n            update_ops.append(tf.compat.v1.assign_add(w, u))\n      if global_step is None:\n        return tf.group(*update_ops)\n      with tf.control_dependencies(update_ops):\n        return tf.compat.v1.assign_add(global_step, 1, name=name).op\n\n  def update_weights(self, train_op):\n    \"\"\"Updates the model weights.\n\n    This function must be called on at least one worker after `minimize`.\n    In distributed training this call can be omitted on non-chief workers to\n    speed up training.\n\n    Args:\n      train_op: The operation returned by the `minimize` call.\n\n    Returns:\n      An Operation that updates the model weights.\n    \"\"\"\n    with tf.control_dependencies([train_op]):\n      update_ops = []\n      # Copy over unshrunk weights to user provided variables.\n      for name in ['sparse_features_weights', 'dense_features_weights']:\n        for var, slot_var in 
zip(self._variables[name],\n                                 self._slots['unshrunk_' + name]):\n          for v, sv in zip(self._var_to_list(var), self._var_to_list(slot_var)):\n            update_ops.append(v.assign(sv))\n\n    # Apply proximal step.\n    if self._symmetric_l1_regularization() > 0:\n      shrinkage = (\n          self._symmetric_l1_regularization() /\n          self._symmetric_l2_regularization())\n      with tf.control_dependencies(update_ops):\n        update_ops = []\n        for name in ['sparse_features_weights', 'dense_features_weights']:\n          for var in self._variables[name]:\n            for v in self._var_to_list(var):\n              with tf.compat.v1.device(v.device):\n                v_shrunk = tf.math.sign(v) * tf.math.maximum(\n                    0.0,\n                    tf.math.abs(v) - shrinkage)\n                update_ops.append(v.assign(v_shrunk))\n        return tf.group(*update_ops)\n    else:\n      return tf.group(*update_ops)\n\n  def approximate_duality_gap(self):\n    \"\"\"Add operations to compute the approximate duality gap.\n\n    Returns:\n      An Operation that computes the approximate duality gap over all\n      examples.\n    \"\"\"\n    with name_scope('sdca/approximate_duality_gap'):\n      _, values_list = self._hashtable.export_sharded()\n      shard_sums = []\n      for values in values_list:\n        with tf.compat.v1.device(values.device):\n          # For large tables to_double() below allocates a large temporary\n          # tensor that is freed once the sum operation completes. 
To reduce\n          # peak memory usage in cases where we have multiple large tables on a\n          # single device, we serialize these operations.\n          # Note that we need double precision to get accurate results.\n          with tf.control_dependencies(shard_sums):\n            shard_sums.append(\n                tf.math.reduce_sum(tf.cast(values, dtype=tf.dtypes.float64), 0))\n      summed_values = tf.math.add_n(shard_sums)\n\n      primal_loss = summed_values[1]\n      dual_loss = summed_values[2]\n      example_weights = summed_values[3]\n      # Note: we return NaN if there are no weights or all weights are 0, e.g.\n      # if no examples have been processed\n      return (primal_loss + dual_loss + self._l1_loss() +\n              (2.0 * self._l2_loss())) / example_weights\n\n  def unregularized_loss(self, examples):\n    \"\"\"Add operations to compute the loss (without the regularization loss).\n\n    Args:\n      examples: Examples to compute unregularized loss on.\n\n    Returns:\n      An Operation that computes mean (unregularized) loss for given set of\n      examples.\n\n    Raises:\n      ValueError: if examples are not well defined.\n    \"\"\"\n    self._assert_specified([\n        'example_labels', 'example_weights', 'sparse_features', 'dense_features'\n    ], examples)\n    self._assert_list(['sparse_features', 'dense_features'], examples)\n    with name_scope('sdca/unregularized_loss'):\n      predictions = tf.cast(\n          self._linear_predictions(examples), tf.dtypes.float64)\n      labels = tf.cast(\n          internal_convert_to_tensor(examples['example_labels']),\n          tf.dtypes.float64)\n      weights = tf.cast(\n          internal_convert_to_tensor(examples['example_weights']),\n          tf.dtypes.float64)\n\n      if self._options['loss_type'] == 'logistic_loss':\n        return tf.math.reduce_sum(\n            tf.math.multiply(\n                sigmoid_cross_entropy_with_logits(\n                    labels=labels, 
logits=predictions),\n                weights)) / tf.math.reduce_sum(weights)\n\n      if self._options['loss_type'] == 'poisson_loss':\n        return tf.math.reduce_sum(\n            tf.math.multiply(\n                log_poisson_loss(targets=labels, log_input=predictions),\n                weights)) / tf.math.reduce_sum(weights)\n\n      if self._options['loss_type'] in ['hinge_loss', 'smooth_hinge_loss']:\n        # hinge_loss = max{0, 1 - y_i w*x} where y_i \\in {-1, 1}. So, we need to\n        # first convert 0/1 labels into -1/1 labels.\n        all_ones = tf.compat.v1.ones_like(predictions)\n        adjusted_labels = tf.math.subtract(2 * labels, all_ones)\n        # Tensor that contains (unweighted) error (hinge loss) per\n        # example.\n        error = tf.nn.relu(\n            tf.math.subtract(all_ones,\n                             tf.math.multiply(adjusted_labels, predictions)))\n        weighted_error = tf.math.multiply(error, weights)\n        return tf.math.reduce_sum(weighted_error) / tf.math.reduce_sum(weights)\n\n      # squared loss\n      err = tf.math.subtract(labels, predictions)\n\n      weighted_squared_err = tf.math.multiply(tf.math.square(err), weights)\n      # SDCA squared loss function is sum(err^2) / (2*sum(weights))\n      return (tf.math.reduce_sum(weighted_squared_err) /\n              (2.0 * tf.math.reduce_sum(weights)))\n\n  def regularized_loss(self, examples):\n    \"\"\"Add operations to compute the loss with regularization loss included.\n\n    Args:\n      examples: Examples to compute loss on.\n\n    Returns:\n      An Operation that computes mean (regularized) loss for given set of\n      examples.\n    Raises:\n      ValueError: if examples are not well defined.\n    \"\"\"\n    self._assert_specified([\n        'example_labels', 'example_weights', 'sparse_features', 'dense_features'\n    ], examples)\n    self._assert_list(['sparse_features', 'dense_features'], examples)\n    with 
name_scope('sdca/regularized_loss'):\n      weights = internal_convert_to_tensor(examples['example_weights'])\n      return ((self._l1_loss() + self._l2_loss()) /\n              tf.math.reduce_sum(tf.cast(weights, tf.dtypes.float64)) +\n              self.unregularized_loss(examples))\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/linear_optimizer/python/utils/sdca_ops_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for SdcaModel.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport random\nimport threading\nimport tensorflow as tf\nfrom tensorflow.core.example import example_pb2\nfrom tensorflow.python.eager import context\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.framework import test_util\nfrom tensorflow.python.framework.test_util import TensorFlowTestCase\nfrom tensorflow.python.platform import googletest\nfrom tensorflow_estimator.python.estimator.canned.linear_optimizer.python.utils.sdca_ops import _SDCAModel\nfrom tensorflow_estimator.python.estimator.canned.linear_optimizer.python.utils.sdca_ops import _SparseFeatureColumn\n\n_MAX_ITERATIONS = 100\n_SHARD_NUMBERS = [None, 1, 3]\n_NUM_LOSS_PARTITIONS = [4]\n\n\ndef make_example_proto(feature_dict, target, value=1.0):\n  e = example_pb2.Example()\n  features = e.features\n\n  features.feature['target'].float_list.value.append(target)\n\n  for key, values in feature_dict.items():\n    features.feature[key + '_indices'].int64_list.value.extend(values)\n    features.feature[key + '_values'].float_list.value.extend([value] *\n                                                              
len(values))\n\n  return e\n\n\ndef make_example_dict(example_protos, example_weights):\n\n  def parse_examples(example_protos):\n    features = {\n        'target':\n            tf.io.FixedLenFeature(\n                shape=[1], dtype=tf.dtypes.float32, default_value=0),\n        'age_indices':\n            tf.io.VarLenFeature(dtype=tf.dtypes.int64),\n        'age_values':\n            tf.io.VarLenFeature(dtype=tf.dtypes.float32),\n        'gender_indices':\n            tf.io.VarLenFeature(dtype=tf.dtypes.int64),\n        'gender_values':\n            tf.io.VarLenFeature(dtype=tf.dtypes.float32)\n    }\n    return tf.compat.v1.io.parse_example(\n        [e.SerializeToString() for e in example_protos], features)\n\n  parsed = parse_examples(example_protos)\n  sparse_features = [\n      _SparseFeatureColumn(\n          tf.reshape(\n              tf.split(\n                  value=parsed['age_indices'].indices,\n                  num_or_size_splits=2,\n                  axis=1)[0], [-1]),\n          tf.reshape(parsed['age_indices'].values, [-1]),\n          tf.reshape(parsed['age_values'].values, [-1])),\n      _SparseFeatureColumn(\n          tf.reshape(\n              tf.split(\n                  value=parsed['gender_indices'].indices,\n                  num_or_size_splits=2,\n                  axis=1)[0], [-1]),\n          tf.reshape(parsed['gender_indices'].values, [-1]),\n          tf.reshape(parsed['gender_values'].values, [-1]))\n  ]\n  return dict(\n      sparse_features=sparse_features,\n      dense_features=[],\n      example_weights=example_weights,\n      example_labels=tf.reshape(parsed['target'], [-1]),\n      example_ids=['%d' % i for i in range(0, len(example_protos))])\n\n\ndef make_random_examples_and_variables_dicts(num_examples, dim, num_non_zero):\n  random.seed(1)\n\n  sparse_features = [\n      _SparseFeatureColumn(\n          [i for i in range(num_examples) for _ in range(num_non_zero)], [\n              i for _ in range(num_examples)\n        
      for i in random.sample(range(dim), num_non_zero)\n          ], [num_non_zero**(-0.5) for _ in range(num_examples * num_non_zero)])\n  ]\n  examples_dict = dict(\n      sparse_features=sparse_features,\n      dense_features=[],\n      example_weights=[random.random() for _ in range(num_examples)],\n      example_labels=[\n          1. if random.random() > 0.5 else 0. for _ in range(num_examples)\n      ],\n      example_ids=[str(i) for i in range(num_examples)])\n\n  weights = tf.compat.v1.Variable(tf.zeros([dim], dtype=tf.dtypes.float32))\n  variables_dict = dict(\n      sparse_features_weights=[weights], dense_features_weights=[])\n\n  return examples_dict, variables_dict\n\n\ndef make_variable_dict(max_age, max_gender, num_shards=None, partitioned=False):\n  # TODO(dbaylor):  Figure out how to derive max_age & max_gender from\n  # examples_dict.\n  partitioner = None\n  if partitioned:\n    partitioner = tf.compat.v1.fixed_size_partitioner(num_shards=2, axis=0)\n  with tf.compat.v1.variable_scope(\n      name_or_scope=('variables/shard_{}'.format(num_shards)\n                     if num_shards else 'variables'),\n      partitioner=partitioner):\n    age_weights = tf.compat.v1.get_variable(\n        name='age',\n        initializer=tf.zeros([max_age + 1], dtype=tf.dtypes.float32))\n    gender_weights = tf.compat.v1.get_variable(\n        name='gender',\n        initializer=tf.zeros([max_gender + 1], dtype=tf.dtypes.float32))\n  return dict(\n      sparse_features_weights=[age_weights, gender_weights],\n      dense_features_weights=[])\n\n\ndef make_dense_examples_and_variables_dicts(dense_features_values, weights,\n                                            labels):\n  \"\"\"Creates examples and variables dictionaries for dense features.\n\n  Variables shapes are inferred from the list of dense feature values passed as\n  argument.\n\n  Args:\n    dense_features_values: The values of the dense features\n    weights: The example weights.\n    labels: The 
example labels.\n\n  Returns:\n    One dictionary for the examples and one for the variables.\n  \"\"\"\n  dense_tensors = []\n  dense_weights = []\n  for dense_feature in dense_features_values:\n    dense_tensor = ops.convert_to_tensor(dense_feature, dtype=tf.dtypes.float32)\n    check_shape_op = tf.debugging.Assert(\n        tf.math.less_equal(tf.rank(dense_tensor), 2),\n        ['dense_tensor shape must be [batch_size, dimension] or [batch_size]'])\n    # Reshape to [batch_size, dense_column_dimension].\n    with tf.control_dependencies([check_shape_op]):\n      dense_tensor = tf.reshape(dense_tensor,\n                                [dense_tensor.get_shape().as_list()[0], -1])\n    dense_tensors.append(dense_tensor)\n    # Add variables of shape [feature_column_dimension].\n    dense_weights.append(\n        tf.compat.v1.Variable(\n            tf.zeros([dense_tensor.get_shape().as_list()[1]],\n                     dtype=tf.dtypes.float32)))\n\n  examples_dict = dict(\n      sparse_features=[],\n      dense_features=dense_tensors,\n      example_weights=weights,\n      example_labels=labels,\n      example_ids=['%d' % i for i in range(0, len(labels))])\n  variables_dict = dict(\n      sparse_features_weights=[], dense_features_weights=dense_weights)\n\n  return examples_dict, variables_dict\n\n\ndef get_binary_predictions_for_logistic(predictions, cutoff=0.5):\n  return tf.cast(\n      tf.math.greater_equal(predictions,\n                            tf.compat.v1.ones_like(predictions) * cutoff),\n      dtype=tf.dtypes.int32)\n\n\ndef get_binary_predictions_for_hinge(predictions):\n  return tf.cast(\n      tf.math.greater_equal(predictions, tf.compat.v1.zeros_like(predictions)),\n      dtype=tf.dtypes.int32)\n\n\n# TODO(pmol): Refactor tests to avoid repetition of boilerplate code.\n\n\nclass _SDCAModelTest(TensorFlowTestCase):\n  \"\"\"Base SDCA optimizer test class for any loss type.\"\"\"\n\n  def _single_threaded_test_session(self):\n    config = 
tf.compat.v1.ConfigProto(\n        inter_op_parallelism_threads=1, intra_op_parallelism_threads=1)\n    return self.test_session(use_gpu=False, config=config)\n\n\n# ResourceVariable only runs in graph mode\n@test_util.deprecated_graph_mode_only\nclass SdcaWithLogisticLossTest(_SDCAModelTest):\n  \"\"\"SDCA optimizer test class for logistic loss.\"\"\"\n\n  def testSimple(self):\n    # Setup test data\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, 0),\n        make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 1),\n    ]\n    example_weights = [1.0, 1.0]\n    for num_shards in _SHARD_NUMBERS:\n      with self._single_threaded_test_session():\n        examples = make_example_dict(example_protos, example_weights)\n        variables = make_variable_dict(1, 1, num_shards)\n        options = dict(\n            symmetric_l2_regularization=1,\n            symmetric_l1_regularization=0,\n            num_table_shards=num_shards,\n            loss_type='logistic_loss')\n\n        lr = _SDCAModel(examples, variables, options)\n        tf.compat.v1.initializers.global_variables().run()\n        unregularized_loss = lr.unregularized_loss(examples)\n        loss = lr.regularized_loss(examples)\n        predictions = lr.predictions(examples)\n        self.assertAllClose(0.693147, unregularized_loss.eval())\n        self.assertAllClose(0.693147, loss.eval())\n        train_op = lr.minimize()\n        for _ in range(_MAX_ITERATIONS):\n          train_op.run()\n        lr.update_weights(train_op).run()\n        # The high tolerance in unregularized_loss comparisons is due to the\n        # fact that it's possible to trade off unregularized_loss vs.\n        # regularization and still have a sum that is quite close to the\n        # optimal regularized_loss value.  
SDCA's duality gap only ensures that\n        # the regularized_loss is within 0.01 of optimal.\n        # 0.525457 is the optimal regularized_loss.\n        # 0.411608 is the unregularized_loss at that optimum.\n        self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05)\n        self.assertAllClose(0.525457, loss.eval(), atol=0.01)\n        predicted_labels = get_binary_predictions_for_logistic(predictions)\n        self.assertAllEqual([0, 1], predicted_labels.eval())\n        self.assertAllClose(\n            0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)\n\n  def testPartitionedPrimals(self):\n    # Setup test data\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, 0),\n        make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 1),\n    ]\n    example_weights = [1.0, 1.0]\n    for num_shards in _SHARD_NUMBERS:\n      with self._single_threaded_test_session():\n        examples = make_example_dict(example_protos, example_weights)\n        variables = make_variable_dict(1, 1, num_shards, partitioned=True)\n        options = dict(\n            symmetric_l2_regularization=1,\n            symmetric_l1_regularization=0,\n            num_table_shards=num_shards,\n            loss_type='logistic_loss')\n\n        lr = _SDCAModel(examples, variables, options)\n        tf.compat.v1.initializers.global_variables().run()\n        unregularized_loss = lr.unregularized_loss(examples)\n        loss = lr.regularized_loss(examples)\n        predictions = lr.predictions(examples)\n        self.assertAllClose(0.693147, unregularized_loss.eval())\n        self.assertAllClose(0.693147, loss.eval())\n        train_op = lr.minimize()\n        for _ in range(_MAX_ITERATIONS):\n          train_op.run()\n        lr.update_weights(train_op).run()\n        # The high tolerance in unregularized_loss comparisons is due to the\n        # fact that it's 
possible to trade off unregularized_loss vs.\n        # regularization and still have a sum that is quite close to the\n        # optimal regularized_loss value.  SDCA's duality gap only ensures that\n        # the regularized_loss is within 0.01 of optimal.\n        # 0.525457 is the optimal regularized_loss.\n        # 0.411608 is the unregularized_loss at that optimum.\n        self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05)\n        self.assertAllClose(0.525457, loss.eval(), atol=0.01)\n        predicted_labels = get_binary_predictions_for_logistic(predictions)\n        self.assertAllEqual([0, 1], predicted_labels.eval())\n        self.assertAllClose(\n            0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)\n\n  def testSomePartitionedPrimals(self):\n    # Setup test data\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, 0),\n        make_example_proto({\n            'age': [0],\n            'gender': [1]\n        }, 1),\n    ]\n    example_weights = [1.0, 1.0]\n    for num_shards in _SHARD_NUMBERS:\n      with self._single_threaded_test_session():\n        examples = make_example_dict(example_protos, example_weights)\n        # Explicitly make age a [1]-shaped Variable (which cannot be\n        # partitioned), while making gender a PartitionedVariable.\n        age_weights = tf.compat.v1.Variable(\n            tf.zeros([1], dtype=tf.dtypes.float32))\n        with tf.compat.v1.variable_scope(\n            name_or_scope=('variables/shard_{}'.format(num_shards)\n                           if num_shards else 'variables'),\n            partitioner=tf.compat.v1.fixed_size_partitioner(\n                num_shards=2, axis=0)):\n          gender_weights = tf.compat.v1.get_variable(\n              name='gender', initializer=tf.zeros([2], dtype=tf.dtypes.float32))\n        variables = dict(\n            sparse_features_weights=[age_weights, gender_weights],\n      
      dense_features_weights=[])\n        options = dict(\n            symmetric_l2_regularization=1,\n            symmetric_l1_regularization=0,\n            num_table_shards=num_shards,\n            loss_type='logistic_loss')\n\n        lr = _SDCAModel(examples, variables, options)\n        tf.compat.v1.initializers.global_variables().run()\n        unregularized_loss = lr.unregularized_loss(examples)\n        loss = lr.regularized_loss(examples)\n        predictions = lr.predictions(examples)\n        self.assertAllClose(0.693147, unregularized_loss.eval())\n        self.assertAllClose(0.693147, loss.eval())\n        train_op = lr.minimize()\n        for _ in range(_MAX_ITERATIONS):\n          train_op.run()\n        lr.update_weights(train_op).run()\n        # The high tolerance in unregularized_loss comparisons is due to the\n        # fact that it's possible to trade off unregularized_loss vs.\n        # regularization and still have a sum that is quite close to the\n        # optimal regularized_loss value.  
 SDCA's duality gap only ensures that\n        # the regularized_loss is within 0.01 of optimal.\n        # 0.593014 is the optimal regularized_loss.\n        # 0.512591 is the unregularized_loss at that optimum.\n        self.assertAllClose(0.512591, unregularized_loss.eval(), atol=0.05)\n        self.assertAllClose(0.593014, loss.eval(), atol=0.01)\n        predicted_labels = get_binary_predictions_for_logistic(predictions)\n        self.assertAllEqual([0, 1], predicted_labels.eval())\n        self.assertAllClose(\n            0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)\n\n  def testSparseRandom(self):\n    dim = 20\n    num_examples = 1000\n    # Number of non-zero features per example.\n    non_zeros = 10\n    # Setup test data.\n    with self._single_threaded_test_session():\n      examples, variables = make_random_examples_and_variables_dicts(\n          num_examples, dim, non_zeros)\n      options = dict(\n          symmetric_l2_regularization=.1,\n          symmetric_l1_regularization=0,\n          num_table_shards=1,\n          adaptive=False,\n          loss_type='logistic_loss')\n\n      lr = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n      train_op = lr.minimize()\n      for _ in range(10):\n        train_op.run()\n      lr.update_weights(train_op).run()\n      self.assertNear(0.0, lr.approximate_duality_gap().eval(), err=1e-2)\n\n  def testSparseDuplicate(self):\n    # Setup test data\n    example_protos = [\n        make_example_proto({\n            'age': [0] * 5,\n            'gender': [0] * 5\n        }, 0),\n        make_example_proto({\n            'age': [1] * 5,\n            'gender': [1] * 5\n        }, 1),\n    ]\n    example_weights = [1.0, 1.0]\n    with self._single_threaded_test_session():\n      examples = make_example_dict(example_protos, example_weights)\n      variables = make_variable_dict(1, 1)\n      options = dict(\n          
symmetric_l2_regularization=1,\n          symmetric_l1_regularization=0,\n          loss_type='logistic_loss')\n\n      lr = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n      train_op = lr.minimize()\n      with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, 'Duplicate'):\n        train_op.run()\n\n  def testDistributedSimple(self):\n    # Distributed SDCA may not converge if the workers update concurrently the\n    # same example. In this test the examples are partitioned across workers.\n    # The examples are the same for all workers, just the example_ids are\n    # different.\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, 0),\n        make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 1),\n    ]\n    example_weights = [1.0, 1.0]\n    examples = make_example_dict(example_protos, example_weights)\n    example_ids = tf.compat.v1.placeholder(\n        tf.dtypes.string, shape=(len(example_weights),))\n    examples['example_ids'] = example_ids\n    variables = make_variable_dict(1, 1)\n    # We need each thread to keep its own device stack or the device scopes\n    # won't be properly nested.\n    tf.compat.v1.get_default_graph().switch_to_thread_local()\n    for num_shards in _SHARD_NUMBERS:\n      for num_loss_partitions in _NUM_LOSS_PARTITIONS:\n        with self._single_threaded_test_session():\n          options = dict(\n              # Keep the same solution as for TestSimple: since the number of\n              # examples is multiplied by num_loss_partitions, multiply also\n              # L2 by the same value.\n              symmetric_l2_regularization=num_loss_partitions,\n              symmetric_l1_regularization=0,\n              loss_type='logistic_loss',\n              num_table_shards=num_shards,\n              num_loss_partitions=num_loss_partitions)\n\n          lr = 
_SDCAModel(examples, variables, options)\n          tf.compat.v1.initializers.global_variables().run()\n          unregularized_loss = lr.unregularized_loss(examples)\n          loss = lr.regularized_loss(examples)\n          predictions = lr.predictions(examples)\n          self.assertAllClose(0.693147, unregularized_loss.eval())\n          self.assertAllClose(0.693147, loss.eval())\n\n          train_op = lr.minimize()\n\n          def minimize(worker_id):\n            with context.graph_mode(), self._single_threaded_test_session():\n              feed_dict = {\n                  example_ids: [\n                      str(i + worker_id * len(example_weights))\n                      for i in range(len(example_weights))\n                  ]\n              }\n              for _ in range(_MAX_ITERATIONS):\n                train_op.run(feed_dict=feed_dict)  # pylint: disable=cell-var-from-loop\n\n          threads = []\n          for worker_id in range(num_loss_partitions):\n            threads.append(threading.Thread(target=minimize, args=(worker_id,)))\n            threads[-1].start()\n\n          for t in threads:\n            t.join()\n          lr.update_weights(train_op).run(feed_dict={\n              example_ids: [str(i) for i in range(len(example_weights))]\n          })\n\n          # Test only the unregularized loss because the optimal value of the\n          # regularized loss depends on num_loss_partitions.\n          self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.02)\n          predicted_labels = get_binary_predictions_for_logistic(predictions)\n          self.assertAllEqual([0, 1], predicted_labels.eval())\n          self.assertNear(0.0, lr.approximate_duality_gap().eval(), 0.02)\n\n  def testSimpleNoL2(self):\n    # L2 regularization of SDCA should be positive.\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, 0),\n        make_example_proto({\n            'age': [1],\n 
           'gender': [1]\n        }, 1),\n    ]\n    example_weights = [1.0, 1.0]\n    with self._single_threaded_test_session():\n      examples = make_example_dict(example_protos, example_weights)\n      variables = make_variable_dict(1, 1, 1)\n      options = dict(\n          symmetric_l2_regularization=0,\n          symmetric_l1_regularization=0,\n          num_table_shards=1,\n          loss_type='logistic_loss')\n\n      with self.assertRaises(ValueError):\n        _SDCAModel(examples, variables, options)\n\n  def testSomeUnweightedExamples(self):\n    # Setup test data with 4 examples, but should produce the same\n    # results as testSimple.\n    example_protos = [\n        # Will be used.\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, 0),\n        # Will be ignored.\n        make_example_proto({\n            'age': [1],\n            'gender': [0]\n        }, 0),\n        # Will be used.\n        make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 1),\n        # Will be ignored.\n        make_example_proto({\n            'age': [1],\n            'gender': [0]\n        }, 1),\n    ]\n    example_weights = [1.0, 0.0, 1.0, 0.0]\n    for num_shards in _SHARD_NUMBERS:\n      with self._single_threaded_test_session():\n        # Only use examples 0 and 2\n        examples = make_example_dict(example_protos, example_weights)\n        variables = make_variable_dict(1, 1, num_shards)\n        options = dict(\n            symmetric_l2_regularization=1,\n            symmetric_l1_regularization=0,\n            num_table_shards=num_shards,\n            loss_type='logistic_loss')\n\n        lr = _SDCAModel(examples, variables, options)\n        tf.compat.v1.initializers.global_variables().run()\n        unregularized_loss = lr.unregularized_loss(examples)\n        loss = lr.regularized_loss(examples)\n        predictions = lr.predictions(examples)\n        train_op = lr.minimize()\n        
for _ in range(_MAX_ITERATIONS):\n          train_op.run()\n        lr.update_weights(train_op).run()\n\n        self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05)\n        self.assertAllClose(0.525457, loss.eval(), atol=0.01)\n        predicted_labels = get_binary_predictions_for_logistic(predictions)\n        self.assertAllClose([0, 1, 1, 1], predicted_labels.eval())\n        self.assertAllClose(\n            0.0, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)\n\n  def testFractionalExampleLabel(self):\n    # Setup test data with 1 positive, and 1 mostly-negative example.\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, 0.1),\n        make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 0.9),\n    ]\n    example_weights = [1.0, 1.0]\n    for num_shards in _SHARD_NUMBERS:\n      with self._single_threaded_test_session():\n        examples = make_example_dict(example_protos, example_weights)\n        variables = make_variable_dict(1, 1, num_shards)\n        options = dict(\n            symmetric_l2_regularization=1,\n            symmetric_l1_regularization=0,\n            num_table_shards=num_shards,\n            loss_type='logistic_loss')\n\n        lr = _SDCAModel(examples, variables, options)\n        tf.compat.v1.initializers.global_variables().run()\n        with self.assertRaisesOpError(\n            'Only labels of 0.0 or 1.0 are supported right now.'):\n          lr.minimize().run()\n\n  def testImbalanced(self):\n    # Setup test data with 1 positive, and 3 negative examples.\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, 0),\n        make_example_proto({\n            'age': [2],\n            'gender': [0]\n        }, 0),\n        make_example_proto({\n            'age': [3],\n            'gender': [0]\n        }, 0),\n        make_example_proto({\n           
 'age': [1],\n            'gender': [1]\n        }, 1),\n    ]\n    example_weights = [1.0, 1.0, 1.0, 1.0]\n    for num_shards in _SHARD_NUMBERS:\n      with self._single_threaded_test_session():\n        examples = make_example_dict(example_protos, example_weights)\n        variables = make_variable_dict(3, 1, num_shards)\n        options = dict(\n            symmetric_l2_regularization=1,\n            symmetric_l1_regularization=0,\n            num_table_shards=num_shards,\n            loss_type='logistic_loss')\n\n        lr = _SDCAModel(examples, variables, options)\n        tf.compat.v1.initializers.global_variables().run()\n        unregularized_loss = lr.unregularized_loss(examples)\n        loss = lr.regularized_loss(examples)\n        predictions = lr.predictions(examples)\n        train_op = lr.minimize()\n        for _ in range(_MAX_ITERATIONS):\n          train_op.run()\n        lr.update_weights(train_op).run()\n\n        self.assertAllClose(\n            0.226487 + 0.102902, unregularized_loss.eval(), atol=0.08)\n        self.assertAllClose(0.328394 + 0.131364, loss.eval(), atol=0.01)\n        predicted_labels = get_binary_predictions_for_logistic(predictions)\n        self.assertAllEqual([0, 0, 0, 1], predicted_labels.eval())\n        self.assertAllClose(\n            0.0, lr.approximate_duality_gap().eval(), rtol=2e-2, atol=1e-2)\n\n  def testImbalancedWithExampleWeights(self):\n    # Setup test data with 1 positive, and 1 negative example.\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, 0),\n        make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 1),\n    ]\n    example_weights = [3.0, 1.0]\n    for num_shards in _SHARD_NUMBERS:\n      with self._single_threaded_test_session():\n        examples = make_example_dict(example_protos, example_weights)\n        variables = make_variable_dict(1, 1, num_shards)\n        options = dict(\n            
symmetric_l2_regularization=1,\n            symmetric_l1_regularization=0,\n            num_table_shards=num_shards,\n            loss_type='logistic_loss')\n\n        lr = _SDCAModel(examples, variables, options)\n        tf.compat.v1.initializers.global_variables().run()\n        unregularized_loss = lr.unregularized_loss(examples)\n        loss = lr.regularized_loss(examples)\n        predictions = lr.predictions(examples)\n        train_op = lr.minimize()\n        for _ in range(_MAX_ITERATIONS):\n          train_op.run()\n        lr.update_weights(train_op).run()\n\n        self.assertAllClose(0.284860, unregularized_loss.eval(), atol=0.08)\n        self.assertAllClose(0.408044, loss.eval(), atol=0.012)\n        predicted_labels = get_binary_predictions_for_logistic(predictions)\n        self.assertAllEqual([0, 1], predicted_labels.eval())\n        self.assertAllClose(\n            0.0, lr.approximate_duality_gap().eval(), rtol=2e-2, atol=1e-2)\n\n  def testInstancesOfOneClassOnly(self):\n    # Setup test data with 1 positive (ignored), and 1 negative example.\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, 0),\n        make_example_proto({\n            'age': [1],\n            'gender': [0]\n        }, 1),  # Shares gender with the instance above.\n    ]\n    example_weights = [1.0, 0.0]  # Second example \"omitted\" from training.\n    for num_shards in _SHARD_NUMBERS:\n      with self._single_threaded_test_session():\n        examples = make_example_dict(example_protos, example_weights)\n        variables = make_variable_dict(1, 1, num_shards)\n        options = dict(\n            symmetric_l2_regularization=1,\n            symmetric_l1_regularization=0,\n            num_table_shards=num_shards,\n            loss_type='logistic_loss')\n\n        lr = _SDCAModel(examples, variables, options)\n        tf.compat.v1.initializers.global_variables().run()\n        unregularized_loss = 
lr.unregularized_loss(examples)\n        loss = lr.regularized_loss(examples)\n        predictions = lr.predictions(examples)\n        train_op = lr.minimize()\n        for _ in range(_MAX_ITERATIONS):\n          train_op.run()\n        lr.update_weights(train_op).run()\n        self.assertAllClose(0.411608, unregularized_loss.eval(), atol=0.05)\n        self.assertAllClose(0.525457, loss.eval(), atol=0.01)\n        predicted_labels = get_binary_predictions_for_logistic(predictions)\n        self.assertAllEqual([0, 0], predicted_labels.eval())\n        self.assertAllClose(\n            0.01, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)\n\n  def testOutOfRangeSparseFeatures(self):\n    # Setup test data\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, 0),\n        make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 1),\n    ]\n    example_weights = [1.0, 1.0]\n    with self._single_threaded_test_session():\n      examples = make_example_dict(example_protos, example_weights)\n      variables = make_variable_dict(0, 0)\n      options = dict(\n          symmetric_l2_regularization=1,\n          symmetric_l1_regularization=0,\n          loss_type='logistic_loss')\n\n      lr = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n      train_op = lr.minimize()\n      with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, 'indices.*'):\n        train_op.run()\n\n  def testOutOfRangeDenseFeatures(self):\n    with self._single_threaded_test_session():\n      examples, variables = make_dense_examples_and_variables_dicts(\n          dense_features_values=[[[1.0, 0.0], [0.0, 1.0]]],\n          weights=[20.0, 10.0],\n          labels=[1.0, 0.0])\n      # Replace with a variable of size 1 instead of 2.\n      variables['dense_features_weights'] = [\n          tf.compat.v1.Variable(tf.zeros([1], 
dtype=tf.dtypes.float32))\n      ]\n      options = dict(\n          symmetric_l2_regularization=1.0,\n          symmetric_l1_regularization=0,\n          loss_type='logistic_loss')\n      lr = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n      train_op = lr.minimize()\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          'More dense features than we have parameters for.*'):\n        train_op.run()\n\n  def testMissingFeature(self):\n    # Setup test data\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, 0),\n        make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 1),\n        make_example_proto({\n            'age': [],\n            'gender': [1]\n        }, 1),\n    ]\n    example_weights = [1.0, 1.0, 1.0]\n    with self._single_threaded_test_session():\n      examples = make_example_dict(example_protos, example_weights)\n      variables = make_variable_dict(1, 1, 1)\n      options = dict(\n          symmetric_l2_regularization=1,\n          symmetric_l1_regularization=0,\n          num_table_shards=1,\n          loss_type='logistic_loss')\n\n      lr = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n      unregularized_loss = lr.unregularized_loss(examples)\n      self.assertAllClose(0.693147, unregularized_loss.eval())\n      train_op = lr.minimize()\n      for _ in range(_MAX_ITERATIONS):\n        train_op.run()\n      lr.update_weights(train_op).run()\n      self.assertAllClose(\n          0.0, lr.approximate_duality_gap().eval(), rtol=1e-2, atol=1e-2)\n\n  # TODO(katsiaspis): add a test for the case when examples at the end of an\n  # epoch are repeated, since example id may be duplicated.\n\n\n# ResourceVariable only runs in graph mode\n@test_util.deprecated_graph_mode_only\nclass 
SdcaWithLinearLossTest(_SDCAModelTest):\n  \"\"\"SDCA optimizer test class for linear (squared) loss.\"\"\"\n\n  def testSimple(self):\n    # Setup test data\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, -10.0),\n        make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 14.0),\n    ]\n    example_weights = [1.0, 1.0]\n    with self._single_threaded_test_session():\n      examples = make_example_dict(example_protos, example_weights)\n      variables = make_variable_dict(1, 1)\n      options = dict(\n          symmetric_l2_regularization=1,\n          symmetric_l1_regularization=0,\n          loss_type='squared_loss')\n\n      lr = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n      predictions = lr.predictions(examples)\n      train_op = lr.minimize()\n      for _ in range(_MAX_ITERATIONS):\n        train_op.run()\n      lr.update_weights(train_op).run()\n\n      # Predictions should be 2/3 of label due to minimizing regularized loss:\n      #   (label - 2 * weight)^2 / 2 + L2 * 2 * weight^2\n      self.assertAllClose([-20.0 / 3.0, 28.0 / 3.0],\n                          predictions.eval(),\n                          rtol=0.005)\n      # Approximate gap should be very close to 0.0. 
(In fact, because the gap\n      # is only approximate, it is likely that upon convergence the duality gap\n      # can have a tiny negative value).\n      self.assertAllClose(0.0, lr.approximate_duality_gap().eval(), atol=1e-2)\n\n  def testL2Regularization(self):\n    # Setup test data\n    example_protos = [\n        # 2 identical examples\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, -10.0),\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, -10.0),\n        # 2 more identical examples\n        make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 14.0),\n        make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 14.0),\n    ]\n    example_weights = [1.0, 1.0, 1.0, 1.0]\n    with self._single_threaded_test_session():\n      examples = make_example_dict(example_protos, example_weights)\n      variables = make_variable_dict(1, 1)\n      options = dict(\n          symmetric_l2_regularization=16,\n          symmetric_l1_regularization=0,\n          loss_type='squared_loss')\n\n      lr = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n      predictions = lr.predictions(examples)\n\n      train_op = lr.minimize()\n      for _ in range(_MAX_ITERATIONS):\n        train_op.run()\n      lr.update_weights(train_op).run()\n\n      # Predictions should be 1/5 of label due to minimizing regularized loss:\n      #   (label - 2 * weight)^2 + L2 * 16 * weight^2\n      optimal1 = -10.0 / 5.0\n      optimal2 = 14.0 / 5.0\n      self.assertAllClose([optimal1, optimal1, optimal2, optimal2],\n                          predictions.eval(),\n                          rtol=0.01)\n\n  def testL1Regularization(self):\n    # Setup test data\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, -10.0),\n        
make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 14.0),\n    ]\n    example_weights = [1.0, 1.0]\n    with self._single_threaded_test_session():\n      examples = make_example_dict(example_protos, example_weights)\n      variables = make_variable_dict(1, 1)\n      options = dict(\n          symmetric_l2_regularization=1.0,\n          symmetric_l1_regularization=4.0,\n          loss_type='squared_loss')\n      lr = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n      prediction = lr.predictions(examples)\n      loss = lr.regularized_loss(examples)\n\n      train_op = lr.minimize()\n      for _ in range(_MAX_ITERATIONS):\n        train_op.run()\n      lr.update_weights(train_op).run()\n\n      # Predictions should be -4, 20/3 due to minimizing regularized loss:\n      #   (label - 2 * weight)^2 / 2 + L2 * 2 * weight^2 + L1 * 4 * weight\n      self.assertAllClose([-4.0, 20.0 / 3.0], prediction.eval(), rtol=0.08)\n\n      # Loss should be the sum of the regularized loss value from above per\n      # example after plugging in the optimal weights.\n      self.assertAllClose(308.0 / 6.0, loss.eval(), atol=0.01)\n\n  def testFeatureValues(self):\n    # Setup test data\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, -10.0, -2.0),\n        make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 14.0, 2.0),\n    ]\n    example_weights = [5.0, 3.0]\n    with self._single_threaded_test_session():\n      examples = make_example_dict(example_protos, example_weights)\n\n      variables = make_variable_dict(1, 1)\n      options = dict(\n          symmetric_l2_regularization=1,\n          symmetric_l1_regularization=0,\n          loss_type='squared_loss')\n\n      lr = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n      predictions = 
lr.predictions(examples)\n\n      train_op = lr.minimize()\n      for _ in range(_MAX_ITERATIONS):\n        train_op.run()\n      lr.update_weights(train_op).run()\n\n      # There are 4 (sparse) variable weights to be learned. 2 for age and 2 for\n      # gender. Let w_1, w_2 be age weights, w_3, w_4 be gender weights, y_1,\n      # y_2 be the labels for examples 1 and 2 respectively and s_1, s_2 the\n      # corresponding *example* weights. With the given feature values, the loss\n      # function is given by:\n      # s_1/2(y_1 + 2w_1 + 2w_3)^2 + s_2/2(y_2 - 2w_2 - 2w_4)^2\n      # + \\lambda/2 (w_1^2 + w_2^2 + w_3^2 + w_4^2). Solving for the optimal, it\n      # can be verified that:\n      # w_1* = w_3* = -2.0 s_1 y_1/(\\lambda + 8 s_1) and\n      # w_2* = w_4* = 2 \\cdot s_2 y_2/(\\lambda + 8 s_2). Equivalently, due to\n      # regularization and example weights, the predictions are within:\n      # 8 \\cdot s_i /(\\lambda + 8 \\cdot s_i) of the labels.\n      self.assertAllClose([-10 * 40.0 / 41.0, 14.0 * 24 / 25.0],\n                          predictions.eval(),\n                          atol=0.01)\n\n  def testDenseFeaturesWithDefaultWeights(self):\n    with self._single_threaded_test_session():\n      examples, variables = make_dense_examples_and_variables_dicts(\n          dense_features_values=[[[1.0], [0.0]], [0.0, 1.0]],\n          weights=[1.0, 1.0],\n          labels=[10.0, -5.0])\n      options = dict(\n          symmetric_l2_regularization=1.0,\n          symmetric_l1_regularization=0,\n          loss_type='squared_loss')\n      lr = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n      predictions = lr.predictions(examples)\n\n      train_op = lr.minimize()\n      for _ in range(_MAX_ITERATIONS):\n        train_op.run()\n      lr.update_weights(train_op).run()\n\n      # The loss function for these particular features is given by:\n      # 1/2(label_1-w_1)^2 + 1/2(label_2-w_2)^2 + \\lambda/2 
(w_1^2 + w_2^2). So,\n      # differentiating wrt to w_1, w_2 yields the following optimal values:\n      # w_1* = label_1/(\\lambda + 1)= 10/2, w_2* =label_2/(\\lambda + 1)= -5/2.\n      # In this case the (unnormalized regularized) loss will be:\n      # 1/2(10-5)^2 + 1/2(5-5/2)^2 + 1/2(5^2 + (5/2)^2) = 125.0/4. The actual\n      # loss should be further normalized by the sum of example weights.\n      self.assertAllClose([5.0, -2.5], predictions.eval(), rtol=0.01)\n      loss = lr.regularized_loss(examples)\n      self.assertAllClose(125.0 / 8.0, loss.eval(), atol=0.01)\n\n  def testDenseFeaturesWithArbitraryWeights(self):\n    with self._single_threaded_test_session():\n      examples, variables = make_dense_examples_and_variables_dicts(\n          dense_features_values=[[[1.0, 0.0], [0.0, 1.0]]],\n          weights=[20.0, 10.0],\n          labels=[10.0, -5.0])\n      options = dict(\n          symmetric_l2_regularization=5.0,\n          symmetric_l1_regularization=0,\n          loss_type='squared_loss')\n      lr = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n      predictions = lr.predictions(examples)\n\n      train_op = lr.minimize()\n      for _ in range(_MAX_ITERATIONS):\n        train_op.run()\n      lr.update_weights(train_op).run()\n\n      # The loss function for these particular features is given by:\n      # 1/2 s_1 (label_1-w_1)^2 + 1/2 s_2(label_2-w_2)^2 +\n      # \\lambda/2 (w_1^2 + w_2^2) where s_1, s_2 are the *example weights. It\n      # turns out that the optimal (variable) weights are given by:\n      # w_1* = label_1 \\cdot s_1/(\\lambda + s_1)= 8.0 and\n      # w_2* =label_2 \\cdot s_2/(\\lambda + s_2)= -10/3.\n      # In this case the (unnormalized regularized) loss will be:\n      # s_1/2(8-10)^2 + s_2/2(5-10/3)^2 + 5.0/2(8^2 + (10/3)^2) = 2175.0/9. 
The\n      # actual loss should be further normalized by the sum of example weights.\n      self.assertAllClose([8.0, -10.0 / 3], predictions.eval(), rtol=0.01)\n      loss = lr.regularized_loss(examples)\n      self.assertAllClose(2175.0 / 270.0, loss.eval(), atol=0.01)\n\n\n# ResourceVariable only runs in graph mode\n@test_util.deprecated_graph_mode_only\nclass SdcaWithHingeLossTest(_SDCAModelTest):\n  \"\"\"SDCA optimizer test class for hinge loss.\"\"\"\n\n  def testSimple(self):\n    # Setup test data\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, 0),\n        make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 1),\n    ]\n    example_weights = [1.0, 1.0]\n    with self._single_threaded_test_session():\n      examples = make_example_dict(example_protos, example_weights)\n      variables = make_variable_dict(1, 1)\n      options = dict(\n          symmetric_l2_regularization=1.0,\n          symmetric_l1_regularization=0,\n          loss_type='hinge_loss')\n      model = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n\n      # Before minimization, the weights default to zero. There is no loss due\n      # to regularization, only unregularized loss which is 0.5 * (1+1) = 1.0.\n      predictions = model.predictions(examples)\n      self.assertAllClose([0.0, 0.0], predictions.eval())\n      unregularized_loss = model.unregularized_loss(examples)\n      regularized_loss = model.regularized_loss(examples)\n      self.assertAllClose(1.0, unregularized_loss.eval())\n      self.assertAllClose(1.0, regularized_loss.eval())\n\n      # After minimization, the model separates perfectly the data points. There\n      # are 4 sparse weights: 2 for age (say w1, w2) and 2 for gender (say w3\n      # and w4). 
Solving the system w1 + w3 = 1.0, w2 + w4 = -1.0 and minimizing\n      # wrt to \\|\\vec{w}\\|_2, gives w1=w3=1/2 and w2=w4=-1/2. This gives 0.0\n      # unregularized loss and 0.25 L2 loss.\n      train_op = model.minimize()\n      for _ in range(_MAX_ITERATIONS):\n        train_op.run()\n      model.update_weights(train_op).run()\n\n      binary_predictions = get_binary_predictions_for_hinge(predictions)\n      self.assertAllEqual([-1.0, 1.0], predictions.eval())\n      self.assertAllEqual([0, 1], binary_predictions.eval())\n      self.assertAllClose(0.0, unregularized_loss.eval())\n      self.assertAllClose(0.25, regularized_loss.eval(), atol=0.05)\n\n  def testDenseFeaturesPerfectlySeparable(self):\n    with self._single_threaded_test_session():\n      examples, variables = make_dense_examples_and_variables_dicts(\n          dense_features_values=[[1.0, 1.0], [1.0, -1.0]],\n          weights=[1.0, 1.0],\n          labels=[1.0, 0.0])\n      options = dict(\n          symmetric_l2_regularization=1.0,\n          symmetric_l1_regularization=0,\n          loss_type='hinge_loss')\n      model = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n      predictions = model.predictions(examples)\n      binary_predictions = get_binary_predictions_for_hinge(predictions)\n\n      train_op = model.minimize()\n      for _ in range(_MAX_ITERATIONS):\n        train_op.run()\n      model.update_weights(train_op).run()\n\n      self.assertAllClose([1.0, -1.0], predictions.eval(), atol=0.05)\n      self.assertAllEqual([1, 0], binary_predictions.eval())\n\n      # (1.0, 1.0) and (1.0, -1.0) are perfectly separable by x-axis (that is,\n      # the SVM's functional margin >=1), so the unregularized loss is ~0.0.\n      # There is only loss due to l2-regularization. 
For these datapoints, it\n      # turns out that w_1~=0.0 and w_2~=1.0 which means that l2 loss is ~0.25.\n      unregularized_loss = model.unregularized_loss(examples)\n      regularized_loss = model.regularized_loss(examples)\n      self.assertAllClose(0.0, unregularized_loss.eval(), atol=0.02)\n      self.assertAllClose(0.25, regularized_loss.eval(), atol=0.02)\n\n  def testDenseFeaturesSeparableWithinMargins(self):\n    with self._single_threaded_test_session():\n      examples, variables = make_dense_examples_and_variables_dicts(\n          dense_features_values=[[[1.0, 0.5], [1.0, -0.5]]],\n          weights=[1.0, 1.0],\n          labels=[1.0, 0.0])\n      options = dict(\n          symmetric_l2_regularization=1.0,\n          symmetric_l1_regularization=0,\n          loss_type='hinge_loss')\n      model = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n      predictions = model.predictions(examples)\n      binary_predictions = get_binary_predictions_for_hinge(predictions)\n\n      train_op = model.minimize()\n      for _ in range(_MAX_ITERATIONS):\n        train_op.run()\n      model.update_weights(train_op).run()\n\n      # (1.0, 0.5) and (1.0, -0.5) are separable by x-axis but the datapoints\n      # are within the margins so there is unregularized loss (1/2 per example).\n      # For these datapoints, optimal weights are w_1~=0.0 and w_2~=1.0 which\n      # gives an L2 loss of ~0.25.\n      self.assertAllClose([0.5, -0.5], predictions.eval(), rtol=0.05)\n      self.assertAllEqual([1, 0], binary_predictions.eval())\n      unregularized_loss = model.unregularized_loss(examples)\n      regularized_loss = model.regularized_loss(examples)\n      self.assertAllClose(0.5, unregularized_loss.eval(), atol=0.02)\n      self.assertAllClose(0.75, regularized_loss.eval(), atol=0.02)\n\n  def testDenseFeaturesWeightedExamples(self):\n    with self._single_threaded_test_session():\n      examples, variables = 
make_dense_examples_and_variables_dicts(\n          dense_features_values=[[[1.0], [1.0]], [[0.5], [-0.5]]],\n          weights=[3.0, 1.0],\n          labels=[1.0, 0.0])\n      options = dict(\n          symmetric_l2_regularization=1.0,\n          symmetric_l1_regularization=0,\n          loss_type='hinge_loss')\n      model = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n      predictions = model.predictions(examples)\n      binary_predictions = get_binary_predictions_for_hinge(predictions)\n      train_op = model.minimize()\n      for _ in range(_MAX_ITERATIONS):\n        train_op.run()\n      model.update_weights(train_op).run()\n\n      # Point (1.0, 0.5) has higher weight than (1.0, -0.5) so the model will\n      # try to increase the margin from (1.0, 0.5). Due to regularization,\n      # (1.0, -0.5) will be within the margin. For these points and example\n      # weights, the optimal weights are w_1~=0.4 and w_2~=1.2 which give an L2\n      # loss of 0.5 * 0.25 * 0.25 * 1.6 = 0.2. 
The binary predictions will be\n      # correct, but the boundary will be much closer to the 2nd point than the\n      # first one.\n      self.assertAllClose([1.0, -0.2], predictions.eval(), atol=0.05)\n      self.assertAllEqual([1, 0], binary_predictions.eval())\n      unregularized_loss = model.unregularized_loss(examples)\n      regularized_loss = model.regularized_loss(examples)\n      self.assertAllClose(0.2, unregularized_loss.eval(), atol=0.02)\n      self.assertAllClose(0.4, regularized_loss.eval(), atol=0.02)\n\n\n# ResourceVariable only runs in graph mode\n@test_util.deprecated_graph_mode_only\nclass SdcaWithSmoothHingeLossTest(_SDCAModelTest):\n  \"\"\"SDCA optimizer test class for smooth hinge loss.\"\"\"\n\n  def testSimple(self):\n    # Setup test data\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, 0),\n        make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 1),\n    ]\n    example_weights = [1.0, 1.0]\n    with self._single_threaded_test_session():\n      examples = make_example_dict(example_protos, example_weights)\n      variables = make_variable_dict(1, 1)\n      options = dict(\n          symmetric_l2_regularization=1.0,\n          symmetric_l1_regularization=0,\n          loss_type='smooth_hinge_loss')\n      model = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n\n      # Before minimization, the weights default to zero. 
There is no loss due\n      # to regularization, only unregularized loss which is 0.5 * (1+1) = 1.0.\n      predictions = model.predictions(examples)\n      self.assertAllClose([0.0, 0.0], predictions.eval())\n      unregularized_loss = model.unregularized_loss(examples)\n      regularized_loss = model.regularized_loss(examples)\n      self.assertAllClose(1.0, unregularized_loss.eval())\n      self.assertAllClose(1.0, regularized_loss.eval())\n\n      # After minimization, the model separates perfectly the data points. There\n      # are 4 sparse weights: 2 for age (say w1, w2) and 2 for gender (say w3\n      # and w4). The minimization leads to w1=w3=1/3 and w2=w4=-1/3. This gives\n      # an unregularized hinge loss of 0.33 and a 0.11 L2 loss\n      train_op = model.minimize()\n      for _ in range(_MAX_ITERATIONS):\n        train_op.run()\n      model.update_weights(train_op).run()\n\n      binary_predictions = get_binary_predictions_for_hinge(predictions)\n      self.assertAllClose([-0.67, 0.67], predictions.eval(), atol=0.05)\n      self.assertAllEqual([0, 1], binary_predictions.eval())\n      self.assertAllClose(0.33, unregularized_loss.eval(), atol=0.02)\n      self.assertAllClose(0.44, regularized_loss.eval(), atol=0.02)\n\n\n# ResourceVariable only runs in graph mode\n@test_util.deprecated_graph_mode_only\nclass SdcaWithPoissonLossTest(_SDCAModelTest):\n  \"\"\"SDCA optimizer test class for poisson loss.\"\"\"\n\n  def testSimple(self):\n    # Setup test data\n    example_protos = [\n        make_example_proto({\n            'age': [0],\n            'gender': [0]\n        }, 0),\n        make_example_proto({\n            'age': [1],\n            'gender': [1]\n        }, 2),\n    ]\n    example_weights = [100.0, 100.0]\n    with self._single_threaded_test_session():\n      examples = make_example_dict(example_protos, example_weights)\n      variables = make_variable_dict(1, 1)\n      options = dict(\n          symmetric_l2_regularization=1.0,\n          
symmetric_l1_regularization=0,\n          loss_type='poisson_loss')\n      model = _SDCAModel(examples, variables, options)\n      tf.compat.v1.initializers.global_variables().run()\n\n      # Before minimization, the weights default to zero. There is no loss due\n      # to regularization, only unregularized loss which is 1 for each example.\n      predictions = model.predictions(examples)\n      self.assertAllClose([1.0, 1.0], predictions.eval())\n      unregularized_loss = model.unregularized_loss(examples)\n      regularized_loss = model.regularized_loss(examples)\n      approximate_duality_gap = model.approximate_duality_gap()\n      self.assertAllClose(1.0, unregularized_loss.eval())\n      self.assertAllClose(1.0, regularized_loss.eval())\n\n      # There are 4 sparse weights: 2 for age (say w1, w2) and 2 for gender\n      # (say w3 and w4). The minimization leads to:\n      # w1=w3=-1.96487, argmin of 100*(exp(2*w)-2*w*0)+w**2.\n      # w2=w4=0.345708, argmin of 100*(exp(2*w)-2*w*2)+w**2.\n      # This gives an unregularized loss of .3167 and .3366 with regularization.\n      train_op = model.minimize()\n      for _ in range(_MAX_ITERATIONS):\n        train_op.run()\n      model.update_weights(train_op).run()\n\n      self.assertAllClose([0.0196, 1.9965], predictions.eval(), atol=1e-4)\n      self.assertAllClose(0.3167, unregularized_loss.eval(), atol=1e-4)\n      self.assertAllClose(0.3366, regularized_loss.eval(), atol=1e-4)\n      self.assertAllClose(0., approximate_duality_gap.eval(), atol=1e-6)\n\n\nclass SdcaFprintTest(_SDCAModelTest):\n  \"\"\"Tests for the SdcaFprint op.\n\n  This is one way of enforcing the platform-agnostic nature of SdcaFprint.\n  Basically we are checking against exact values and this test could be running\n  across different platforms. 
Note that it is fine for expected values to change\n  in the future, if the implementation of SdcaFprint changes (ie this is *not* a\n  frozen test).\n  \"\"\"\n\n  def testFprint(self):\n    with self._single_threaded_test_session():\n      in_data = tf.constant(['abc', 'very looooooong string', 'def'])\n      out_data = tf.compat.v1.train.sdca_fprint(in_data)\n      self.assertAllEqual([[4143508125394299908, -6879828354153669051],\n                           [5849691694103072671, -4874542629849009556],\n                           [603227410218889250, 8762207001949257490]],\n                          self.evaluate(out_data))\n\n\nclass _SparseFeatureColumnTest(TensorFlowTestCase):\n  \"\"\"Tests for _SparseFeatureColumn.\"\"\"\n\n  def testBasic(self):\n    expected_example_indices = [1, 1, 1, 2]\n    expected_feature_indices = [0, 1, 2, 0]\n    sfc = _SparseFeatureColumn(expected_example_indices,\n                               expected_feature_indices, None)\n    self.assertIsInstance(sfc.example_indices, tf.Tensor)\n    self.assertIsInstance(sfc.feature_indices, tf.Tensor)\n    self.assertEqual(sfc.feature_values, None)\n    with self.cached_session():\n      self.assertAllEqual(expected_example_indices,\n                          self.evaluate(sfc.example_indices))\n      self.assertAllEqual(expected_feature_indices,\n                          self.evaluate(sfc.feature_indices))\n    expected_feature_values = [1.0, 2.0, 3.0, 4.0]\n    sfc = _SparseFeatureColumn([1, 1, 1, 2], [0, 1, 2, 0],\n                               expected_feature_values)\n    with self.cached_session():\n      self.assertAllEqual(expected_feature_values,\n                          self.evaluate(sfc.feature_values))\n\n\nif __name__ == '__main__':\n  googletest.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/linear_optimizer/python/utils/sharded_mutable_dense_hashtable.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Sharded mutable dense hash table.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport functools\n\nfrom six.moves import range\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.ops import gen_lookup_ops\nfrom tensorflow.python.ops import lookup_ops\nfrom tensorflow.python.training.saver import BaseSaverBuilder\nfrom tensorflow.python.checkpoint import saveable_compat\n\n\n@saveable_compat.legacy_saveable_name(\"table\")\nclass _MutableDenseHashTable(lookup_ops.LookupInterface):\n  \"\"\"Copy of tf.contrib.lookup.MutableDenseHashTable.\"\"\"\n\n  # TODO(b/118148303): Swap this with the core version\n  def __init__(self,\n               key_dtype,\n               value_dtype,\n               default_value,\n               empty_key,\n               deleted_key,\n               initial_num_buckets=None,\n               shared_name=None,\n               name=\"MutableDenseHashTable\",\n               checkpoint=True):\n    \"\"\"Creates an empty `_MutableDenseHashTable` object.\n\n    Creates a table, the type of its keys and values are specified by key_dtype\n    and value_dtype, respectively.\n\n    Args:\n      key_dtype: the type 
of the key tensors.\n      value_dtype: the type of the value tensors.\n      default_value: The value to use if a key is missing in the table.\n      empty_key: the key to use to represent empty buckets internally. Must not\n        be used in insert, remove or lookup operations.\n      deleted_key: the key to use to represent deleted buckets internally. Must\n        not be used in insert, remove or lookup operations and be different from\n        the empty_key.\n      initial_num_buckets: the initial number of buckets.\n      shared_name: If non-empty, this table will be shared under the given name\n        across multiple sessions.\n      name: A name for the operation (optional).\n      checkpoint: if True, the contents of the table are saved to and restored\n        from checkpoints. If `shared_name` is empty for a checkpointed table, it\n        is shared using the table node name.\n\n    Returns:\n      A `_MutableDenseHashTable` object.\n\n    Raises:\n      ValueError: If checkpoint is True and no name was specified.\n    \"\"\"\n    self._default_value = ops.convert_to_tensor(\n        default_value, dtype=value_dtype, name=\"default_value\")\n    self._key_dtype = key_dtype\n    self._value_dtype = value_dtype\n    self._initial_num_buckets = initial_num_buckets\n    self._value_shape = self._default_value.get_shape()\n    self._checkpoint = checkpoint\n    self._name = name\n\n    self._empty_key = ops.convert_to_tensor(\n        empty_key, dtype=key_dtype, name=\"empty_key\")\n    self._deleted_key = ops.convert_to_tensor(\n        deleted_key, dtype=key_dtype, name=\"deleted_key\")\n    if tf.executing_eagerly() and shared_name is None:\n      # TODO(allenl): This will leak memory due to kernel caching by the\n      # shared_name attribute value (but is better than the alternative of\n      # sharing everything by default when executing eagerly; hopefully creating\n      # tables in a loop is uncommon).\n      shared_name = \"table_%d\" % 
(ops.uid(),)\n    self._shared_name = shared_name\n    super(_MutableDenseHashTable, self).__init__(key_dtype, value_dtype)\n\n    self._resource_handle = self._create_resource()\n    if checkpoint:\n      saveable = _MutableDenseHashTable._Saveable(self, name)\n      if not tf.executing_eagerly():\n        tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.SAVEABLE_OBJECTS,\n                                       saveable)\n\n  def _create_resource(self):\n    # The table must be shared if checkpointing is requested for multi-worker\n    # training to work correctly. Use the node name if no shared_name has been\n    # explicitly specified.\n    use_node_name_sharing = self._checkpoint and self._shared_name is None\n    table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(\n        empty_key=self._empty_key,\n        deleted_key=self._deleted_key,\n        shared_name=self._shared_name,\n        use_node_name_sharing=use_node_name_sharing,\n        value_dtype=self._value_dtype,\n        value_shape=self._value_shape,\n        initial_num_buckets=self._initial_num_buckets,\n        name=self._name)\n    if tf.executing_eagerly():\n      self._table_name = None\n    else:\n      self._table_name = table_ref.op.name.split(\"/\")[-1]\n    return table_ref\n\n  @property\n  def name(self):\n    return self._table_name\n\n  def size(self, name=None):\n    \"\"\"Compute the number of elements in this table.\n\n    Args:\n      name: A name for the operation (optional).\n\n    Returns:\n      A scalar tensor containing the number of elements in this table.\n    \"\"\"\n    with ops.name_scope(name, \"%s_Size\" % self.name,\n                        [self.resource_handle]) as name:\n      with ops.colocate_with(self.resource_handle):\n        return gen_lookup_ops.lookup_table_size_v2(\n            self.resource_handle, name=name)\n\n  def lookup(self, keys, name=None):\n    \"\"\"Looks up `keys` in a table, outputs the corresponding values.\n\n    The `default_value` 
is used for keys not present in the table.\n\n    Args:\n      keys: Keys to look up. Can be a tensor of any shape. Must match the\n        table's key_dtype.\n      name: A name for the operation (optional).\n\n    Returns:\n      A tensor containing the values in the same shape as `keys` using the\n        table's value type.\n\n    Raises:\n      TypeError: when `keys` do not match the table data types.\n    \"\"\"\n    with ops.name_scope(name, \"%s_lookup_table_find\" % self.name,\n                        [self.resource_handle, keys]) as name:\n      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name=\"keys\")\n      with ops.colocate_with(self.resource_handle):\n        values = gen_lookup_ops.lookup_table_find_v2(\n            self.resource_handle, keys, self._default_value, name=name)\n\n    return values\n\n  def insert(self, keys, values, name=None):\n    \"\"\"Associates `keys` with `values`.\n\n    Args:\n      keys: Keys to insert. Can be a tensor of any shape. Must match the table's\n        key type.\n      values: Values to be associated with keys. 
Must be a tensor of the same\n        shape as `keys` and match the table's value type.\n      name: A name for the operation (optional).\n\n    Returns:\n      The created Operation.\n\n    Raises:\n      TypeError: when `keys` or `values` doesn't match the table data\n        types.\n    \"\"\"\n    with ops.name_scope(name, \"%s_lookup_table_insert\" % self.name,\n                        [self.resource_handle, keys, values]) as name:\n      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name=\"keys\")\n      values = ops.convert_to_tensor(\n          values, dtype=self._value_dtype, name=\"values\")\n      with ops.colocate_with(self.resource_handle):\n        op = gen_lookup_ops.lookup_table_insert_v2(\n            self.resource_handle, keys, values, name=name)\n      return op\n\n  def export(self, name=None):\n    \"\"\"Returns tensors of all keys and values in the table.\n\n    Args:\n      name: A name for the operation (optional).\n\n    Returns:\n      A pair of tensors with the first tensor containing all keys and the\n        second tensors containing all values in the table.\n    \"\"\"\n    with ops.name_scope(name, \"%s_lookup_table_export_values\" % self.name,\n                        [self.resource_handle]) as name:\n      with ops.colocate_with(self.resource_handle):\n        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(\n            self.resource_handle, self._key_dtype, self._value_dtype, name=name)\n\n    return exported_keys, exported_values\n\n  def _serialize_to_tensors(self):\n    tesnors = self.export()\n    return {\"-keys\": tesnors[0], \"-values\": tesnors[1]}\n\n  def _restore_from_tensors(self, restored_tensors):\n    with ops.colocate_with(self.resource_handle):\n      return gen_lookup_ops.lookup_table_import_v2(self.resource_handle,\n                                                   restored_tensors[\"-keys\"],\n                                                   
restored_tensors[\"-values\"])\n\n  class _Saveable(BaseSaverBuilder.SaveableObject):\n    \"\"\"SaveableObject implementation for _MutableDenseHashTable.\"\"\"\n\n    def __init__(self, table, name):\n      tensors = table.export()\n      specs = [\n          BaseSaverBuilder.SaveSpec(tensors[0], \"\", name + \"-keys\"),\n          BaseSaverBuilder.SaveSpec(tensors[1], \"\", name + \"-values\")\n      ]\n      # pylint: disable=protected-access\n      super(_MutableDenseHashTable._Saveable, self).__init__(table, specs, name)\n\n    def restore(self, restored_tensors, restored_shapes):\n      del restored_shapes  # unused\n      # pylint: disable=protected-access\n      with ops.colocate_with(self.op.resource_handle):\n        return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,\n                                                     restored_tensors[0],\n                                                     restored_tensors[1])\n\n\n# TODO(rohanj): This should subclass Checkpointable and implement\n# _gather_saveables_for_checkpoint.\nclass _ShardedMutableDenseHashTable(object):\n  \"\"\"A sharded version of _MutableDenseHashTable.\n\n  It is designed to be interface compatible with LookupInterface and\n  MutableDenseHashTable, with the exception of the export method, which is\n  replaced by an export_sharded method.\n\n  The _ShardedMutableDenseHashTable keeps `num_shards` _MutableDenseHashTable\n  internally. 
The shard is computed via the modulo operation on the key.\n  \"\"\"\n\n  def __init__(self,\n               key_dtype,\n               value_dtype,\n               default_value,\n               empty_key,\n               deleted_key,\n               num_shards=1,\n               checkpoint=True,\n               name=\"ShardedMutableHashTable\"):\n    self._key_dtype = key_dtype\n    self._value_dtype = value_dtype\n    with ops.name_scope(name, \"sharded_mutable_hash_table\") as scope:\n      table_shards = []\n      for i in range(num_shards):\n        self._table_name = scope\n        table_shards.append(\n            _MutableDenseHashTable(\n                key_dtype=key_dtype,\n                value_dtype=value_dtype,\n                default_value=default_value,\n                empty_key=empty_key,\n                deleted_key=deleted_key,\n                checkpoint=checkpoint,\n                name=\"%s-%d-of-%d\" % (name, i + 1, num_shards)))\n      self._table_shards = table_shards\n      # TODO(andreasst): add a value_shape() method to LookupInterface\n      # pylint: disable=protected-access\n      self._value_shape = self._table_shards[0]._value_shape\n      # pylint: enable=protected-access\n\n  @property\n  def name(self):\n    return self._table_name\n\n  @property\n  def _num_shards(self):\n    return len(self._table_shards)\n\n  @property\n  def table_shards(self):\n    return self._table_shards\n\n  def size(self, name=None):\n    with ops.name_scope(name, \"sharded_mutable_hash_table_size\"):\n      sizes = [self._table_shards[i].size() for i in range(self._num_shards)]\n      return tf.math.add_n(sizes)\n\n  def _shard_indices(self, keys):\n    key_shape = keys.get_shape()\n    if key_shape.ndims > 1:\n      # If keys are a matrix (i.e. 
a single key is a vector), we use the first\n      # element of each key vector to determine the shard.\n      keys = tf.reshape(tf.slice(keys, [0, 0], [-1, 1]), [-1])\n    indices = tf.math.floormod(tf.math.abs(keys), self._num_shards)\n    return tf.cast(indices, tf.dtypes.int32)\n\n  def _check_keys(self, keys):\n    if keys.get_shape().ndims != 1 and keys.get_shape().ndims != 2:\n      raise ValueError(\"Expected a vector or matrix for keys, got %s.\" %\n                       keys.get_shape())\n\n  def lookup(self, keys, name=None):\n    \"\"\"Looks up `keys` in a table, outputs the corresponding values.\"\"\"\n    if keys.dtype.base_dtype != self._key_dtype:\n      raise TypeError(\"Signature mismatch. Keys must be dtype %s, got %s.\" %\n                      (self._key_dtype, keys.dtype))\n    self._check_keys(keys)\n    num_shards = self._num_shards\n    if num_shards == 1:\n      return self._table_shards[0].lookup(keys, name=name)\n\n    shard_indices = self._shard_indices(keys)\n    key_shards = tf.dynamic_partition(keys, shard_indices, num_shards)\n    value_shards = [\n        self._table_shards[i].lookup(key_shards[i], name=name)\n        for i in range(num_shards)\n    ]\n\n    num_keys = tf.compat.v1.shape(keys)[0]\n    original_indices = tf.range(num_keys)\n    partitioned_indices = tf.dynamic_partition(original_indices, shard_indices,\n                                               num_shards)\n    return tf.dynamic_stitch(partitioned_indices, value_shards)\n\n  def insert(self, keys, values, name=None):\n    \"\"\"Inserts `keys` in a table.\"\"\"\n    self._check_keys(keys)\n    num_shards = self._num_shards\n    if num_shards == 1:\n      return self._table_shards[0].insert(keys, values, name=name)\n\n    shard_indices = self._shard_indices(keys)\n    key_shards = tf.dynamic_partition(keys, shard_indices, num_shards)\n    value_shards = tf.dynamic_partition(values, shard_indices, num_shards)\n    return_values = [\n        
self._table_shards[i].insert(key_shards[i], value_shards[i], name=name)\n        for i in range(num_shards)\n    ]\n\n    return tf.group(*return_values)\n\n  def export_sharded(self, name=None):\n    \"\"\"Returns lists of the keys and values tensors in the sharded table.\n\n    Args:\n      name: name of the table.\n\n    Returns:\n      A pair of lists with the first list containing the key tensors and the\n        second list containing the value tensors from each shard.\n    \"\"\"\n    keys_list = []\n    values_list = []\n    for table_shard in self._table_shards:\n      exported_keys, exported_values = table_shard.export(name=name)\n      keys_list.append(exported_keys)\n      values_list.append(exported_values)\n    return keys_list, values_list\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/linear_optimizer/python/utils/sharded_mutable_dense_hashtable_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for sharded_mutable_dense_hashtable.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\nfrom tensorflow.python.platform import googletest\nfrom tensorflow_estimator.python.estimator.canned.linear_optimizer.python.utils.sharded_mutable_dense_hashtable import _ShardedMutableDenseHashTable\n\n\nclass _ShardedMutableDenseHashTableTest(tf.test.TestCase):\n  \"\"\"Tests for the ShardedMutableHashTable class.\"\"\"\n\n  def testShardedMutableHashTable(self):\n    for num_shards in [1, 3, 10]:\n      with self.cached_session():\n        default_val = -1\n        empty_key = 0\n        deleted_key = -1\n        keys = tf.constant([11, 12, 13], tf.dtypes.int64)\n        values = tf.constant([0, 1, 2], tf.dtypes.int64)\n        table = _ShardedMutableDenseHashTable(\n            tf.dtypes.int64,\n            tf.dtypes.int64,\n            default_val,\n            empty_key,\n            deleted_key,\n            num_shards=num_shards)\n        self.assertAllEqual(0, self.evaluate(table.size()))\n\n        self.evaluate(table.insert(keys, values))\n        self.assertAllEqual(3, self.evaluate(table.size()))\n\n        input_string = tf.constant([11, 12, 14], 
tf.dtypes.int64)\n        output = table.lookup(input_string)\n        self.assertAllEqual([3], output.get_shape())\n        self.assertAllEqual([0, 1, -1], self.evaluate(output))\n\n  def testShardedMutableHashTableVectors(self):\n    for num_shards in [1, 3, 10]:\n      with self.cached_session():\n        default_val = [-0.1, 0.2]\n        empty_key = [0, 1]\n        deleted_key = [1, 0]\n        keys = tf.constant([[11, 12], [13, 14], [15, 16]], tf.dtypes.int64)\n        values = tf.constant([[0.5, 0.6], [1.5, 1.6], [2.5, 2.6]],\n                             tf.dtypes.float32)\n        table = _ShardedMutableDenseHashTable(\n            tf.dtypes.int64,\n            tf.dtypes.float32,\n            default_val,\n            empty_key,\n            deleted_key,\n            num_shards=num_shards)\n        self.assertAllEqual(0, self.evaluate(table.size()))\n\n        self.evaluate(table.insert(keys, values))\n        self.assertAllEqual(3, self.evaluate(table.size()))\n\n        input_string = tf.constant([[11, 12], [13, 14], [11, 14]],\n                                   tf.dtypes.int64)\n        output = table.lookup(input_string)\n        self.assertAllEqual([3, 2], output.get_shape())\n        self.assertAllClose([[0.5, 0.6], [1.5, 1.6], [-0.1, 0.2]],\n                            self.evaluate(output))\n\n  def testExportSharded(self):\n    with self.cached_session():\n      empty_key = -2\n      deleted_key = -3\n      default_val = -1\n      num_shards = 2\n      keys = tf.constant([10, 11, 12], tf.dtypes.int64)\n      values = tf.constant([2, 3, 4], tf.dtypes.int64)\n      table = _ShardedMutableDenseHashTable(\n          tf.dtypes.int64,\n          tf.dtypes.int64,\n          default_val,\n          empty_key,\n          deleted_key,\n          num_shards=num_shards)\n      self.assertAllEqual(0, self.evaluate(table.size()))\n\n      self.evaluate(table.insert(keys, values))\n      self.assertAllEqual(3, self.evaluate(table.size()))\n\n      keys_list, 
values_list = table.export_sharded()\n      self.assertAllEqual(num_shards, len(keys_list))\n      self.assertAllEqual(num_shards, len(values_list))\n\n      # Exported keys include empty key buckets set to the empty_key\n      self.assertAllEqual(\n          set([-2, 10, 12]), set(self.evaluate(keys_list[0]).flatten()))\n      self.assertAllEqual(\n          set([-2, 11]), set(self.evaluate(keys_list[1]).flatten()))\n      # Exported values include empty value buckets set to 0\n      self.assertAllEqual(\n          set([0, 2, 4]), set(self.evaluate(values_list[0]).flatten()))\n      self.assertAllEqual(\n          set([0, 3]), set(self.evaluate(values_list[1]).flatten()))\n\n\nif __name__ == '__main__':\n  googletest.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/linear_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for linear.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\n\nfrom tensorflow.python.feature_column import feature_column_v2\nfrom tensorflow.python.framework import ops\nfrom tensorflow_estimator.python.estimator.canned import linear\nfrom tensorflow_estimator.python.estimator.canned import linear_testing_utils\n\n\ndef _linear_regressor_fn(*args, **kwargs):\n  return linear.LinearRegressorV2(*args, **kwargs)\n\n\ndef _linear_classifier_fn(*args, **kwargs):\n  return linear.LinearClassifierV2(*args, **kwargs)\n\n\n# Tests for Linear Regressor.\n\n\nclass LinearRegressorEvaluationV2Test(\n    linear_testing_utils.BaseLinearRegressorEvaluationTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\nclass LinearRegressorPredictV2Test(\n    linear_testing_utils.BaseLinearRegressorPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    
tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearRegressorPredictTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\nclass LinearRegressorIntegrationV2Test(\n    linear_testing_utils.BaseLinearRegressorIntegrationTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\nclass LinearRegressorTrainingV2Test(\n    linear_testing_utils.BaseLinearRegressorTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\n# Tests for Linear Classifier.\n\n\nclass LinearClassifierTrainingV2Test(\n    linear_testing_utils.BaseLinearClassifierTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(\n        self,\n        linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\nclass LinearClassifierEvaluationV2Test(\n    linear_testing_utils.BaseLinearClassifierEvaluationTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(\n        self,\n        linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\nclass LinearClassifierPredictV2Test(\n    linear_testing_utils.BaseLinearClassifierPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # 
pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearClassifierPredictTest.__init__(\n        self,\n        linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\nclass LinearClassifierIntegrationV2Test(\n    linear_testing_utils.BaseLinearClassifierIntegrationTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(\n        self,\n        linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\n# Tests for Linear logit_fn.\n\n\nclass LinearLogitFnV2Test(linear_testing_utils.BaseLinearLogitFnTest,\n                          tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearLogitFnTest.__init__(\n        self, fc_lib=feature_column_v2)\n\n\n# Tests for warm-starting with Linear logit_fn.\n\n\nclass LinearWarmStartingV2Test(linear_testing_utils.BaseLinearWarmStartingTest,\n                               tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils.BaseLinearWarmStartingTest.__init__(\n        self,\n        _linear_classifier_fn,\n        _linear_regressor_fn,\n        fc_lib=feature_column_v2)\n\n\nclass ComputeFractionOfZeroTest(tf.test.TestCase):\n\n  def _assertSparsity(self, expected_sparsity, tensor):\n    sparsity = linear._compute_fraction_of_zero([tensor])\n    self.assertAllClose(expected_sparsity, sparsity)\n\n  def test_small_float32(self):\n    self._assertSparsity(\n        0.75, ops.convert_to_tensor([0, 0, 0, 1], dtype=tf.dtypes.float32))\n    self._assertSparsity(\n        0.5, ops.convert_to_tensor([0, 1, 0, 1], 
dtype=tf.dtypes.float32))\n\n  def test_small_int32(self):\n    self._assertSparsity(\n        0.75, ops.convert_to_tensor([0, 0, 0, 1], dtype=tf.dtypes.int32))\n\n  def test_small_float64(self):\n    self._assertSparsity(\n        0.75, ops.convert_to_tensor([0, 0, 0, 1], dtype=tf.dtypes.float64))\n\n  def test_small_int64(self):\n    self._assertSparsity(\n        0.75, ops.convert_to_tensor([0, 0, 0, 1], dtype=tf.dtypes.int64))\n\n  def test_nested(self):\n    self._assertSparsity(\n        0.75, [ops.convert_to_tensor([0, 0]),\n               ops.convert_to_tensor([0, 1])])\n\n  def test_none(self):\n    with self.assertRaises(ValueError):\n      linear._compute_fraction_of_zero([])\n\n  def test_empty(self):\n    sparsity = linear._compute_fraction_of_zero([ops.convert_to_tensor([])])\n    self.assertTrue(\n        self.evaluate(tf.math.is_nan(sparsity)),\n        'Expected sparsity=nan, got %s' % sparsity)\n\n  def test_multiple_empty(self):\n    sparsity = linear._compute_fraction_of_zero([\n        ops.convert_to_tensor([]),\n        ops.convert_to_tensor([]),\n    ])\n    self.assertTrue(\n        self.evaluate(tf.math.is_nan(sparsity)),\n        'Expected sparsity=nan, got %s' % sparsity)\n\n  def test_some_empty(self):\n    with self.test_session():\n      self._assertSparsity(0.5, [\n          ops.convert_to_tensor([]),\n          ops.convert_to_tensor([0.]),\n          ops.convert_to_tensor([1.]),\n      ])\n\n  def test_mixed_types(self):\n    with self.test_session():\n      self._assertSparsity(0.6, [\n          ops.convert_to_tensor([0, 0, 1, 1, 1], dtype=tf.dtypes.float32),\n          ops.convert_to_tensor([0, 0, 0, 0, 1], dtype=tf.dtypes.int32),\n      ])\n\n  def test_2_27_zeros__using_512_MiB_of_ram(self):\n    self._assertSparsity(1., tf.zeros([int(2**27 * 1.01)],\n                                      dtype=tf.dtypes.int8))\n\n  def test_2_27_ones__using_512_MiB_of_ram(self):\n    self._assertSparsity(0., tf.ones([int(2**27 * 1.01)], 
dtype=tf.dtypes.int8))\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/linear_testing_utils.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Utils for testing linear estimators.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\nimport os\nimport shutil\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.core.example import example_pb2\nfrom tensorflow.core.example import feature_pb2\nfrom tensorflow.python.feature_column import feature_column_v2\nfrom tensorflow.python.framework import ops\nfrom tensorflow_estimator.python.estimator import estimator\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import linear\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.inputs import pandas_io\n\ntry:\n  # pylint: disable=g-import-not-at-top\n  import pandas as pd\n  HAS_PANDAS = True\nexcept IOError:\n  # Pandas writes a temporary file during import. 
If it fails, don't use pandas.\n  HAS_PANDAS = False\nexcept ImportError:\n  HAS_PANDAS = False\n\n# pylint rules which are disabled by default for test files.\n# pylint: disable=invalid-name,protected-access,missing-docstring\n\n# Names of variables created by model.\nAGE_WEIGHT_NAME = 'linear/linear_model/age/weights'\nHEIGHT_WEIGHT_NAME = 'linear/linear_model/height/weights'\nOCCUPATION_WEIGHT_NAME = 'linear/linear_model/occupation/weights'\nBIAS_NAME = 'linear/linear_model/bias_weights'\nLANGUAGE_WEIGHT_NAME = 'linear/linear_model/language/weights'\n\n\ndef assert_close(expected, actual, rtol=1e-04, name='assert_close'):\n  with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope:\n    expected = ops.convert_to_tensor(expected, name='expected')\n    actual = ops.convert_to_tensor(actual, name='actual')\n    rdiff = tf.math.abs(expected - actual, 'diff') / tf.math.abs(expected)\n    rtol = ops.convert_to_tensor(rtol, name='rtol')\n    return tf.compat.v1.debugging.assert_less(\n        rdiff,\n        rtol,\n        data=('Condition expected =~ actual did not hold element-wise:'\n              'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff,\n              'rtol = ', rtol,),\n        name=scope)\n\n\ndef save_variables_to_ckpt(model_dir):\n  init_all_op = [tf.compat.v1.initializers.global_variables()]\n  with tf.compat.v1.Session() as sess:\n    sess.run(init_all_op)\n    tf.compat.v1.train.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))\n\n\ndef queue_parsed_features(feature_map):\n  tensors_to_enqueue = []\n  keys = []\n  for key, tensor in six.iteritems(feature_map):\n    keys.append(key)\n    tensors_to_enqueue.append(tensor)\n  queue_dtypes = [x.dtype for x in tensors_to_enqueue]\n  input_queue = tf.queue.FIFOQueue(capacity=100, dtypes=queue_dtypes)\n  tf.compat.v1.train.queue_runner.add_queue_runner(\n      tf.compat.v1.train.queue_runner.QueueRunner(\n          input_queue, 
[input_queue.enqueue(tensors_to_enqueue)]))\n  dequeued_tensors = input_queue.dequeue()\n  return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}\n\n\ndef sorted_key_dict(unsorted_dict):\n  return {k: unsorted_dict[k] for k in sorted(unsorted_dict)}\n\n\ndef sigmoid(x):\n  return 1 / (1 + np.exp(-1.0 * x))\n\n\ndef mock_optimizer(testcase, expected_loss=None):\n  expected_var_names = ['%s:0' % AGE_WEIGHT_NAME, '%s:0' % BIAS_NAME]\n\n  class _Optimizer(tf_keras.optimizers.legacy.Optimizer):\n\n    def get_updates(self, loss, params):\n      trainable_vars = params\n      testcase.assertItemsEqual(expected_var_names,\n                                [var.name for var in trainable_vars])\n\n      # Verify loss. We can't check the value directly, so we add an assert op.\n      testcase.assertEquals(0, loss.shape.ndims)\n      if expected_loss is None:\n        if self.iterations is not None:\n          return [self.iterations.assign_add(1).op]\n        return [tf.no_op()]\n      assert_loss = assert_close(\n          tf.cast(expected_loss, name='expected', dtype=tf.dtypes.float32),\n          loss,\n          name='assert_loss')\n      with tf.control_dependencies((assert_loss,)):\n        if self.iterations is not None:\n          return [self.iterations.assign_add(1).op]\n        return [tf.no_op()]\n\n    def get_config(self):\n      config = super(_Optimizer, self).get_config()\n      return config\n\n  optimizer = _Optimizer(name='my_optimizer')\n\n  return optimizer\n\n\n# TODO(b/36813849): Add tests with dynamic shape inputs using placeholders.\nclass BaseLinearRegressorEvaluationTest(object):\n\n  def __init__(self, linear_regressor_fn, fc_lib=feature_column_v2):\n    self._linear_regressor_fn = linear_regressor_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      
shutil.rmtree(self._model_dir)\n\n  def test_evaluation_for_simple_data(self):\n    with tf.Graph().as_default():\n      tf.Variable([[11.0]], name=AGE_WEIGHT_NAME)\n      tf.Variable([2.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        model_dir=self._model_dir)\n    eval_metrics = linear_regressor.evaluate(\n        input_fn=lambda: ({\n            'age': ((1,),)\n        }, ((10.,),)), steps=1)\n\n    # Logit is (1. * 11.0 + 2.0) = 13, while label is 10. Loss is 3**2 = 9.\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 9.,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def test_evaluation_batch(self):\n    \"\"\"Tests evaluation for batch_size==2.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([[11.0]], name=AGE_WEIGHT_NAME)\n      tf.Variable([2.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        model_dir=self._model_dir)\n    eval_metrics = linear_regressor.evaluate(\n        input_fn=lambda: ({\n            'age': ((1,), (1,))\n        }, ((10.,), (10.,))), steps=1)\n\n    # Logit is (1. 
* 11.0 + 2.0) = 13, while label is 10.\n    # Loss per example is 3**2 = 9.\n    # Training loss is the sum over batch size = (9 + 9) / 2 = 9\n    # Average loss is the average over batch = 9\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 9.,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def test_evaluation_weights(self):\n    \"\"\"Tests evaluation with weights.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([[11.0]], name=AGE_WEIGHT_NAME)\n      tf.Variable([2.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    def _input_fn():\n      features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))}\n      labels = ((10.,), (10.,))\n      return features, labels\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        weight_column='weights',\n        model_dir=self._model_dir)\n    eval_metrics = linear_regressor.evaluate(input_fn=_input_fn, steps=1)\n\n    # Logit is (1. 
* 11.0 + 2.0) = 13, while label is 10.\n    # Loss per example is 3**2 = 9.\n    # Training loss is the weighted sum over batch / batch size =\n    #     (9 + 2*9) / 2 = 13.5\n    # average loss is the weighted average = 9 + 2*9 / (1 + 2) = 9\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 13.5,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def test_evaluation_for_multi_dimensions(self):\n    x_dim = 3\n    label_dim = 2\n    with tf.Graph().as_default():\n      tf.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name=AGE_WEIGHT_NAME)\n      tf.Variable([7.0, 8.0], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age', shape=(x_dim,)),),\n        label_dimension=label_dim,\n        model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': np.array([[2., 4., 5.]]),\n        },\n        y=np.array([[46., 58.]]),\n        batch_size=1,\n        num_epochs=None,\n        shuffle=False)\n    eval_metrics = linear_regressor.evaluate(input_fn=input_fn, steps=1)\n\n    self.assertItemsEqual(\n        (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,\n         metric_keys.MetricKeys.PREDICTION_MEAN,\n         metric_keys.MetricKeys.LABEL_MEAN, tf.compat.v1.GraphKeys.GLOBAL_STEP),\n        eval_metrics.keys())\n\n    # Logit is\n    #   [2., 4., 5.] 
* [1.0, 2.0] + [7.0, 8.0] = [39, 50] + [7.0, 8.0]\n    #                  [3.0, 4.0]\n    #                  [5.0, 6.0]\n    # which is [46, 58]\n    self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])\n\n  def test_evaluation_for_multiple_feature_columns(self):\n    with tf.Graph().as_default():\n      tf.Variable([[10.0]], name=AGE_WEIGHT_NAME)\n      tf.Variable([[2.0]], name=HEIGHT_WEIGHT_NAME)\n      tf.Variable([5.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    batch_size = 2\n    feature_columns = [\n        self._fc_lib.numeric_column('age'),\n        self._fc_lib.numeric_column('height')\n    ]\n    input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': np.array([20, 40]),\n            'height': np.array([4, 8])\n        },\n        y=np.array([[213.], [421.]]),\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=False)\n\n    est = self._linear_regressor_fn(\n        feature_columns=feature_columns, model_dir=self._model_dir)\n\n    eval_metrics = est.evaluate(input_fn=input_fn, steps=1)\n    self.assertItemsEqual(\n        (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,\n         metric_keys.MetricKeys.PREDICTION_MEAN,\n         metric_keys.MetricKeys.LABEL_MEAN, tf.compat.v1.GraphKeys.GLOBAL_STEP),\n        eval_metrics.keys())\n\n    # Logit is [(20. * 10.0 + 4 * 2.0 + 5.0), (40. * 10.0 + 8 * 2.0 + 5.0)] =\n    # [213.0, 421.0], while label is [213., 421.]. 
Loss = 0.\n    self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])\n\n  def test_evaluation_for_multiple_feature_columns_mix(self):\n    with tf.Graph().as_default():\n      tf.Variable([[10.0]], name=AGE_WEIGHT_NAME)\n      tf.Variable([[2.0]], name=HEIGHT_WEIGHT_NAME)\n      tf.Variable([5.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    batch_size = 2\n    feature_columns = [\n        tf.feature_column.numeric_column('age'),\n        tf.feature_column.numeric_column('height')\n    ]\n\n    def _input_fn():\n      features_ds = tf.compat.v1.data.Dataset.from_tensor_slices({\n          'age': np.array([20, 40]),\n          'height': np.array([4, 8])\n      })\n      labels_ds = tf.compat.v1.data.Dataset.from_tensor_slices(\n          np.array([[213.], [421.]]))\n      return (tf.compat.v1.data.Dataset.zip(\n          (features_ds, labels_ds)).batch(batch_size).repeat(None))\n\n    est = self._linear_regressor_fn(\n        feature_columns=feature_columns, model_dir=self._model_dir)\n\n    eval_metrics = est.evaluate(input_fn=_input_fn, steps=1)\n    self.assertItemsEqual(\n        (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,\n         metric_keys.MetricKeys.PREDICTION_MEAN,\n         metric_keys.MetricKeys.LABEL_MEAN, tf.compat.v1.GraphKeys.GLOBAL_STEP),\n        eval_metrics.keys())\n\n    # Logit is [(20. * 10.0 + 4 * 2.0 + 5.0), (40. * 10.0 + 8 * 2.0 + 5.0)] =\n    # [213.0, 421.0], while label is [213., 421.]. 
Loss = 0.\n    self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])\n\n\nclass BaseLinearRegressorPredictTest(object):\n\n  def __init__(self, linear_regressor_fn, fc_lib=feature_column_v2):\n    self._linear_regressor_fn = linear_regressor_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_1d(self):\n    \"\"\"Tests predict when all variables are one-dimensional.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([[10.]], name='linear/linear_model/x/weights')\n      tf.Variable([.2], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('x'),),\n        model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': np.array([[2.]])},\n        y=None,\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    predictions = linear_regressor.predict(input_fn=predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # x * weight + bias = 2. * 10. 
+ .2 = 20.2\n    self.assertAllClose([[20.2]], predicted_scores)\n\n  def testMultiDim(self):\n    \"\"\"Tests predict when all variables are multi-dimenstional.\"\"\"\n    batch_size = 2\n    label_dimension = 3\n    x_dim = 4\n    feature_columns = (self._fc_lib.numeric_column('x', shape=(x_dim,)),)\n    with tf.Graph().as_default():\n      tf.Variable(  # shape=[x_dim, label_dimension]\n          [[1., 2., 3.], [2., 3., 4.], [3., 4., 5.], [4., 5., 6.]],\n          name='linear/linear_model/x/weights')\n      tf.Variable(  # shape=[label_dimension]\n          [.2, .4, .6], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=feature_columns,\n        label_dimension=label_dimension,\n        model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        # x shape=[batch_size, x_dim]\n        x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])},\n        y=None,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n    predictions = linear_regressor.predict(input_fn=predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # score = x * weight + bias, shape=[batch_size, label_dimension]\n    self.assertAllClose([[30.2, 40.4, 50.6], [70.2, 96.4, 122.6]],\n                        predicted_scores)\n\n  def testTwoFeatureColumns(self):\n    \"\"\"Tests predict with two feature columns.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([[10.]], name='linear/linear_model/x0/weights')\n      tf.Variable([[20.]], name='linear/linear_model/x1/weights')\n      tf.Variable([.2], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('x0'),\n       
                  self._fc_lib.numeric_column('x1')),\n        model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={\n            'x0': np.array([[2.]]),\n            'x1': np.array([[3.]])\n        },\n        y=None,\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    predictions = linear_regressor.predict(input_fn=predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # x0 * weight0 + x1 * weight1 + bias = 2. * 10. + 3. * 20 + .2 = 80.2\n    self.assertAllClose([[80.2]], predicted_scores)\n\n  def testTwoFeatureColumnsMix(self):\n    \"\"\"Tests predict with two feature columns.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([[10.]], name='linear/linear_model/x0/weights')\n      tf.Variable([[20.]], name='linear/linear_model/x1/weights')\n      tf.Variable([.2], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(tf.feature_column.numeric_column('x0'),\n                         tf.feature_column.numeric_column('x1')),\n        model_dir=self._model_dir)\n\n    def _predict_input_fn():\n      return tf.compat.v1.data.Dataset.from_tensor_slices({\n          'x0': np.array([[2.]]),\n          'x1': np.array([[3.]])\n      }).batch(1)\n\n    predictions = linear_regressor.predict(input_fn=_predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # x0 * weight0 + x1 * weight1 + bias = 2. * 10. + 3. 
* 20 + .2 = 80.2\n    self.assertAllClose([[80.2]], predicted_scores)\n\n  def testSparseCombiner(self):\n    w_a = 2.0\n    w_b = 3.0\n    w_c = 5.0\n    bias = 5.0\n    with tf.Graph().as_default():\n      tf.Variable([[w_a], [w_b], [w_c]], name=LANGUAGE_WEIGHT_NAME)\n      tf.Variable([bias], name=BIAS_NAME)\n      tf.Variable(\n          1, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    def _input_fn():\n      return tf.compat.v1.data.Dataset.from_tensors({\n          'language':\n              tf.sparse.SparseTensor(\n                  values=['a', 'c', 'b', 'c'],\n                  indices=[[0, 0], [0, 1], [1, 0], [1, 1]],\n                  dense_shape=[2, 2]),\n      })\n\n    feature_columns = (self._fc_lib.categorical_column_with_vocabulary_list(\n        'language', vocabulary_list=['a', 'b', 'c']),)\n\n    # Check prediction for each sparse_combiner.\n    # With sparse_combiner = 'sum', we have\n    # logits_1 = w_a + w_c + bias\n    #          = 2.0 + 5.0 + 5.0 = 12.0\n    # logits_2 = w_b + w_c + bias\n    #          = 3.0 + 5.0 + 5.0 = 13.0\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=feature_columns, model_dir=self._model_dir)\n    predictions = linear_regressor.predict(input_fn=_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    self.assertAllClose([[12.0], [13.0]], predicted_scores)\n\n    # With sparse_combiner = 'mean', we have\n    # logits_1 = 1/2 * (w_a + w_c) + bias\n    #          = 1/2 * (2.0 + 5.0) + 5.0 = 8.5\n    # logits_2 = 1/2 * (w_b + w_c) + bias\n    #          = 1/2 * (3.0 + 5.0) + 5.0 = 9.0\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=feature_columns,\n        model_dir=self._model_dir,\n        sparse_combiner='mean')\n    predictions = linear_regressor.predict(input_fn=_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    
self.assertAllClose([[8.5], [9.0]], predicted_scores)\n\n    # With sparse_combiner = 'sqrtn', we have\n    # logits_1 = sqrt(2)/2 * (w_a + w_c) + bias\n    #          = sqrt(2)/2 * (2.0 + 5.0) + 5.0 = 9.94974\n    # logits_2 = sqrt(2)/2 * (w_b + w_c) + bias\n    #          = sqrt(2)/2 * (3.0 + 5.0) + 5.0 = 10.65685\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=feature_columns,\n        model_dir=self._model_dir,\n        sparse_combiner='sqrtn')\n    predictions = linear_regressor.predict(input_fn=_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    self.assertAllClose([[9.94974], [10.65685]], predicted_scores)\n\n\nclass BaseLinearRegressorIntegrationTest(object):\n\n  def __init__(self, linear_regressor_fn, fc_lib=feature_column_v2):\n    self._linear_regressor_fn = linear_regressor_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, label_dimension, prediction_length):\n    feature_columns = [\n        self._fc_lib.numeric_column('x', shape=(input_dimension,))\n    ]\n    est = self._linear_regressor_fn(\n        feature_columns=feature_columns,\n        label_dimension=label_dimension,\n        model_dir=self._model_dir)\n\n    # TRAIN\n    # learn y = x\n    est.train(train_input_fn, steps=200)\n\n    # EVALUTE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(200, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores))\n\n    # PREDICT\n    predictions = np.array(\n        [x['predictions'] for x in est.predict(predict_input_fn)])\n    self.assertAllEqual((prediction_length, label_dimension), predictions.shape)\n\n    # 
EXPORT\n    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    label_dimension = 2\n    input_dimension = label_dimension\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=None,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        label_dimension=label_dimension,\n        prediction_length=prediction_length)\n\n  def test_pandas_input_fn(self):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n\n    # Pandas DataFrame natually supports 1 dim data only.\n    label_dimension = 1\n    input_dimension = label_dimension\n    batch_size = 10\n    data = np.array([1., 2., 3., 4.], dtype=np.float32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(data)\n    prediction_length = 4\n\n    train_input_fn = pandas_io.pandas_input_fn(\n     
   x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        label_dimension=label_dimension,\n        prediction_length=prediction_length)\n\n  def test_input_fn_from_parse_example(self):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    label_dimension = 2\n    input_dimension = label_dimension\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n\n    serialized_examples = []\n    for datum in data:\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n                  'y':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(\n                              value=datum[:label_dimension])),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([input_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([label_dimension], tf.dtypes.float32),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    
def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=input_dimension,\n        label_dimension=label_dimension,\n        prediction_length=prediction_length)\n\n\nclass BaseLinearRegressorTrainingTest(object):\n\n  def __init__(self, linear_regressor_fn, fc_lib=feature_column_v2):\n    self._linear_regressor_fn = linear_regressor_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _assert_checkpoint(self,\n                         expected_global_step,\n                         expected_age_weight=None,\n                         expected_bias=None):\n    shapes = {\n        name: shape\n        for (name, shape) in tf.train.list_variables(self._model_dir)\n    }\n\n    self.assertEqual([], shapes[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertEqual(\n        expected_global_step,\n        tf.train.load_variable(self._model_dir,\n                               tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n    self.assertEqual([1, 1], shapes[AGE_WEIGHT_NAME])\n    if expected_age_weight is not None:\n      self.assertEqual(expected_age_weight,\n                       
tf.train.load_variable(self._model_dir, AGE_WEIGHT_NAME))\n\n    self.assertEqual([1], shapes[BIAS_NAME])\n    if expected_bias is not None:\n      self.assertEqual(expected_bias,\n                       tf.train.load_variable(self._model_dir, BIAS_NAME))\n\n  def testFromScratchWithDefaultOptimizer(self):\n    # Create LinearRegressor.\n    label = 5.\n    age = 17\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        model_dir=self._model_dir)\n\n    # Train for a few steps, and validate final checkpoint.\n    num_steps = 10\n    linear_regressor.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self._assert_checkpoint(num_steps)\n\n  def testTrainWithOneDimLabel(self):\n    label_dimension = 1\n    batch_size = 20\n    feature_columns = [self._fc_lib.numeric_column('age', shape=(1,))]\n    est = self._linear_regressor_fn(\n        feature_columns=feature_columns,\n        label_dimension=label_dimension,\n        model_dir=self._model_dir)\n    data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32)\n    self.assertEqual((batch_size,), data_rank_1.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'age': data_rank_1},\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(200)\n\n  def testTrainWithOneDimWeight(self):\n    label_dimension = 1\n    batch_size = 20\n    feature_columns = [self._fc_lib.numeric_column('age', shape=(1,))]\n    est = self._linear_regressor_fn(\n        feature_columns=feature_columns,\n        label_dimension=label_dimension,\n        weight_column='w',\n        model_dir=self._model_dir)\n\n    data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32)\n    self.assertEqual((batch_size,), data_rank_1.shape)\n\n    train_input_fn = 
numpy_io.numpy_input_fn(\n        x={\n            'age': data_rank_1,\n            'w': data_rank_1\n        },\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(200)\n\n  def testFromScratch(self):\n    # Create LinearRegressor.\n    label = 5.\n    age = 17\n    # loss = (logits - label)^2 = (0 - 5.)^2 = 25.\n    mock_opt = mock_optimizer(self, expected_loss=25.)\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        model_dir=self._model_dir,\n        optimizer=mock_opt)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    linear_regressor.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self.assertEqual(\n        num_steps,\n        linear_regressor.get_variable_value(mock_opt.iterations.name))\n    self._assert_checkpoint(\n        expected_global_step=num_steps,\n        expected_age_weight=0.,\n        expected_bias=0.)\n\n  def testFromCheckpoint(self):\n    # Create initial checkpoint.\n    age_weight = 10.0\n    bias = 5.0\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable([[age_weight]], name=AGE_WEIGHT_NAME)\n      tf.Variable([bias], name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # logits = age * age_weight + bias = 17 * 10. + 5. 
= 175\n    # loss = (logits - label)^2 = (175 - 5)^2 = 28900\n    mock_opt = mock_optimizer(self, expected_loss=28900.)\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        model_dir=self._model_dir,\n        optimizer=mock_opt)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    linear_regressor.train(\n        input_fn=lambda: ({\n            'age': ((17,),)\n        }, ((5.,),)), steps=num_steps)\n    self.assertEqual(\n        initial_global_step + num_steps,\n        linear_regressor.get_variable_value(mock_opt.iterations.name))\n    self._assert_checkpoint(\n        expected_global_step=initial_global_step + num_steps,\n        expected_age_weight=age_weight,\n        expected_bias=bias)\n\n  def testFromCheckpointMultiBatch(self):\n    # Create initial checkpoint.\n    age_weight = 10.0\n    bias = 5.0\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable([[age_weight]], name=AGE_WEIGHT_NAME)\n      tf.Variable([bias], name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # logits = age * age_weight + bias\n    # logits[0] = 17 * 10. + 5. = 175\n    # logits[1] = 15 * 10. + 5. 
= 155\n    # loss = sum(logits - label)^2 = (175 - 5)^2 + (155 - 3)^2 = 52004\n    # expected_loss = loss / 2 = 26002\n    mock_opt = mock_optimizer(self, expected_loss=26002.)\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        model_dir=self._model_dir,\n        optimizer=mock_opt)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    linear_regressor.train(\n        input_fn=lambda: ({\n            'age': ((17,), (15,))\n        }, ((5.,), (3.,))),\n        steps=num_steps)\n    self.assertEqual(\n        initial_global_step + num_steps,\n        linear_regressor.get_variable_value(mock_opt.iterations.name))\n    self._assert_checkpoint(\n        expected_global_step=initial_global_step + num_steps,\n        expected_age_weight=age_weight,\n        expected_bias=bias)\n\n\nclass BaseLinearClassifierTrainingTest(object):\n\n  def __init__(self, linear_classifier_fn, fc_lib=feature_column_v2):\n    self._linear_classifier_fn = linear_classifier_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _assert_checkpoint(self,\n                         n_classes,\n                         expected_global_step,\n                         expected_age_weight=None,\n                         expected_bias=None):\n    logits_dimension = n_classes if n_classes > 2 else 1\n\n    shapes = {\n        name: shape\n        for (name, shape) in tf.train.list_variables(self._model_dir)\n    }\n\n    self.assertEqual([], shapes[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertEqual(\n        expected_global_step,\n        tf.train.load_variable(self._model_dir,\n                               tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n    self.assertEqual([1, logits_dimension], shapes[AGE_WEIGHT_NAME])\n    if expected_age_weight 
is not None:\n      self.assertAllEqual(\n          expected_age_weight,\n          tf.train.load_variable(self._model_dir, AGE_WEIGHT_NAME))\n\n    self.assertEqual([logits_dimension], shapes[BIAS_NAME])\n    if expected_bias is not None:\n      self.assertAllEqual(expected_bias,\n                          tf.train.load_variable(self._model_dir, BIAS_NAME))\n\n  def _testFromScratchWithDefaultOptimizer(self, n_classes):\n    label = 0\n    age = 17\n    est = linear.LinearClassifierV2(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, and validate final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self._assert_checkpoint(n_classes, num_steps)\n\n  def testBinaryClassesFromScratchWithDefaultOptimizer(self):\n    self._testFromScratchWithDefaultOptimizer(n_classes=2)\n\n  def testMultiClassesFromScratchWithDefaultOptimizer(self):\n    self._testFromScratchWithDefaultOptimizer(n_classes=4)\n\n  def _testTrainWithTwoDimsLabel(self, n_classes):\n    batch_size = 20\n\n    est = linear.LinearClassifierV2(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    data_rank_2 = np.array([[0], [1]])\n    self.assertEqual((2,), data_rank_1.shape)\n    self.assertEqual((2, 1), data_rank_2.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'age': data_rank_1},\n        y=data_rank_2,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithTwoDimsLabel(self):\n    self._testTrainWithTwoDimsLabel(n_classes=2)\n\n  def testMultiClassesTrainWithTwoDimsLabel(self):\n    
self._testTrainWithTwoDimsLabel(n_classes=4)\n\n  def _testTrainWithOneDimLabel(self, n_classes):\n    batch_size = 20\n\n    est = linear.LinearClassifierV2(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    self.assertEqual((2,), data_rank_1.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'age': data_rank_1},\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithOneDimLabel(self):\n    self._testTrainWithOneDimLabel(n_classes=2)\n\n  def testMultiClassesTrainWithOneDimLabel(self):\n    self._testTrainWithOneDimLabel(n_classes=4)\n\n  def _testTrainWithTwoDimsWeight(self, n_classes):\n    batch_size = 20\n\n    est = linear.LinearClassifierV2(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        weight_column='w',\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    data_rank_2 = np.array([[0], [1]])\n    self.assertEqual((2,), data_rank_1.shape)\n    self.assertEqual((2, 1), data_rank_2.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': data_rank_1,\n            'w': data_rank_2\n        },\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithTwoDimsWeight(self):\n    self._testTrainWithTwoDimsWeight(n_classes=2)\n\n  def testMultiClassesTrainWithTwoDimsWeight(self):\n    self._testTrainWithTwoDimsWeight(n_classes=4)\n\n  def _testTrainWithOneDimWeight(self, n_classes):\n    batch_size = 20\n\n    est = linear.LinearClassifierV2(\n        
feature_columns=(self._fc_lib.numeric_column('age'),),\n        weight_column='w',\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    self.assertEqual((2,), data_rank_1.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': data_rank_1,\n            'w': data_rank_1\n        },\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithOneDimWeight(self):\n    self._testTrainWithOneDimWeight(n_classes=2)\n\n  def testMultiClassesTrainWithOneDimWeight(self):\n    self._testTrainWithOneDimWeight(n_classes=4)\n\n  def _testFromScratch(self, n_classes):\n    label = 1\n    age = 17\n    # For binary classifier:\n    #   loss = sigmoid_cross_entropy(logits, label) where logits=0 (weights are\n    #   all zero initially) and label = 1 so,\n    #      loss = 1 * -log ( sigmoid(logits) ) = 0.69315\n    # For multi class classifier:\n    #   loss = cross_entropy(logits, label) where logits are all 0s (weights are\n    #   all zero initially) and label = 1 so,\n    #      loss = 1 * -log ( 1.0 / n_classes )\n    # For this particular test case, as logits are same, the formular\n    # 1 * -log ( 1.0 / n_classes ) covers both binary and multi class cases.\n    mock_opt = mock_optimizer(\n        self, expected_loss=-1 * math.log(1.0 / n_classes))\n\n    est = linear.LinearClassifierV2(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        optimizer=mock_opt,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self.assertEqual(num_steps,\n                     
est.get_variable_value(mock_opt.iterations.name))\n    self._assert_checkpoint(\n        n_classes,\n        expected_global_step=num_steps,\n        expected_age_weight=[[0.]] if n_classes == 2 else [[0.] * n_classes],\n        expected_bias=[0.] if n_classes == 2 else [.0] * n_classes)\n\n  def testBinaryClassesFromScratch(self):\n    self._testFromScratch(n_classes=2)\n\n  def testMultiClassesFromScratch(self):\n    self._testFromScratch(n_classes=4)\n\n  def _testFromCheckpoint(self, n_classes):\n    # Create initial checkpoint.\n    label = 1\n    age = 17\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    age_weight = [[2.0]] if n_classes == 2 else (np.reshape(\n        2.0 * np.array(list(range(n_classes)), dtype=np.float32),\n        (1, n_classes)))\n    bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(age_weight, name=AGE_WEIGHT_NAME)\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # For binary classifier:\n    #   logits = age * age_weight + bias = 17 * 2. - 35. 
= -1.\n    #   loss = sigmoid_cross_entropy(logits, label)\n    #   so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133\n    # For multi class classifier:\n    #   loss = cross_entropy(logits, label)\n    #   where logits = 17 * age_weight + bias and label = 1\n    #   so, loss = 1 * -log ( soft_max(logits)[1] )\n    if n_classes == 2:\n      expected_loss = 1.3133\n    else:\n      logits = age_weight * age + bias\n      logits_exp = np.exp(logits)\n      softmax = logits_exp / logits_exp.sum()\n      expected_loss = -1 * math.log(softmax[0, label])\n\n    mock_opt = mock_optimizer(self, expected_loss=expected_loss)\n\n    est = linear.LinearClassifierV2(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        optimizer=mock_opt,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self.assertEqual(initial_global_step + num_steps,\n                     est.get_variable_value(mock_opt.iterations.name))\n    self._assert_checkpoint(\n        n_classes,\n        expected_global_step=initial_global_step + num_steps,\n        expected_age_weight=age_weight,\n        expected_bias=bias)\n\n  def testBinaryClassesFromCheckpoint(self):\n    self._testFromCheckpoint(n_classes=2)\n\n  def testMultiClassesFromCheckpoint(self):\n    self._testFromCheckpoint(n_classes=4)\n\n  def _testFromCheckpointFloatLabels(self, n_classes):\n    \"\"\"Tests float labels for binary classification.\"\"\"\n    # Create initial checkpoint.\n    if n_classes > 2:\n      return\n    label = 0.8\n    age = 17\n    age_weight = [[2.0]]\n    bias = [-35.0]\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(age_weight, name=AGE_WEIGHT_NAME)\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          
initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # logits = age * age_weight + bias = 17 * 2. - 35. = -1.\n    # loss = sigmoid_cross_entropy(logits, label)\n    # => loss = -0.8 * log(sigmoid(-1)) -0.2 * log(sigmoid(+1)) = 1.1132617\n    mock_opt = mock_optimizer(self, expected_loss=1.1132617)\n\n    est = linear.LinearClassifierV2(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        optimizer=mock_opt,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self.assertEqual(initial_global_step + num_steps,\n                     est.get_variable_value(mock_opt.iterations.name))\n\n  def testBinaryClassesFromCheckpointFloatLabels(self):\n    self._testFromCheckpointFloatLabels(n_classes=2)\n\n  def testMultiClassesFromCheckpointFloatLabels(self):\n    self._testFromCheckpointFloatLabels(n_classes=4)\n\n  def _testFromCheckpointMultiBatch(self, n_classes):\n    # Create initial checkpoint.\n    label = [1, 0]\n    age = [17.0, 18.5]\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). 
In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    age_weight = [[2.0]] if n_classes == 2 else (np.reshape(\n        2.0 * np.array(list(range(n_classes)), dtype=np.float32),\n        (1, n_classes)))\n    bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(age_weight, name=AGE_WEIGHT_NAME)\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # For binary classifier:\n    #   logits = age * age_weight + bias\n    #   logits[0] = 17 * 2. - 35. = -1.\n    #   logits[1] = 18.5 * 2. - 35. = 2.\n    #   loss = sigmoid_cross_entropy(logits, label)\n    #   so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133\n    #       loss[1] = (1 - 0) * -log ( 1- sigmoid(2) ) = 2.1269\n    #   expected_loss = (loss[0] + loss[1]) / batch size (2)\n    # For multi class classifier:\n    #   loss = cross_entropy(logits, label)\n    #   where logits = [17, 18.5] * age_weight + bias and label = [1, 0]\n    #   so, loss = 1 * -log ( soft_max(logits)[label] )\n    #   expected_loss = (loss[0] + loss[1]) / batch size (2)\n    if n_classes == 2:\n      expected_loss = (1.3133 + 2.1269) / 2\n    else:\n      logits = age_weight * np.reshape(age, (2, 1)) + bias\n      logits_exp = np.exp(logits)\n      softmax_row_0 = logits_exp[0] / logits_exp[0].sum()\n      softmax_row_1 = logits_exp[1] / logits_exp[1].sum()\n      expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])\n      expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])\n      expected_loss = (expected_loss_0 + expected_loss_1) / 2\n\n    mock_opt = mock_optimizer(self, expected_loss=expected_loss)\n\n    est = linear.LinearClassifierV2(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n       
 optimizer=mock_opt,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(input_fn=lambda: ({'age': (age)}, (label)), steps=num_steps)\n    self.assertEqual(initial_global_step + num_steps,\n                     est.get_variable_value(mock_opt.iterations.name))\n    self._assert_checkpoint(\n        n_classes,\n        expected_global_step=initial_global_step + num_steps,\n        expected_age_weight=age_weight,\n        expected_bias=bias)\n\n  def testBinaryClassesFromCheckpointMultiBatch(self):\n    self._testFromCheckpointMultiBatch(n_classes=2)\n\n  def testMultiClassesFromCheckpointMultiBatch(self):\n    self._testFromCheckpointMultiBatch(n_classes=4)\n\n\nclass BaseLinearClassifierEvaluationTest(object):\n\n  def __init__(self, linear_classifier_fn, fc_lib=feature_column_v2):\n    self._linear_classifier_fn = linear_classifier_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _test_evaluation_for_simple_data(self, n_classes):\n    label = 1\n    age = 1.\n\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). 
In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    age_weight = [[-11.0]] if n_classes == 2 else (np.reshape(\n        -11.0 * np.array(list(range(n_classes)), dtype=np.float32),\n        (1, n_classes)))\n    bias = [-30.0] if n_classes == 2 else [-30.0] * n_classes\n\n    with tf.Graph().as_default():\n      tf.Variable(age_weight, name=AGE_WEIGHT_NAME)\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = self._linear_classifier_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n    eval_metrics = est.evaluate(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=1)\n\n    if n_classes == 2:\n      # Binary classes: loss = sum(cross_entropy(41)) = 41.\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: 41.,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: 41.,\n          metric_keys.MetricKeys.ACCURACY: 0.,\n          metric_keys.MetricKeys.PRECISION: 0.,\n          metric_keys.MetricKeys.RECALL: 0.,\n          metric_keys.MetricKeys.PREDICTION_MEAN: 0.,\n          metric_keys.MetricKeys.LABEL_MEAN: 1.,\n          metric_keys.MetricKeys.ACCURACY_BASELINE: 1,\n          metric_keys.MetricKeys.AUC: 0.,\n          metric_keys.MetricKeys.AUC_PR: 1.,\n      }\n    else:\n      # Multi classes: loss = 1 * -log ( soft_max(logits)[label] )\n      logits = age_weight * age + bias\n      logits_exp = np.exp(logits)\n      softmax = logits_exp / logits_exp.sum()\n      expected_loss = -1 * math.log(softmax[0, label])\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          metric_keys.MetricKeys.LOSS_MEAN: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n      
    metric_keys.MetricKeys.ACCURACY: 0.,\n      }\n\n    self.assertAllClose(\n        sorted_key_dict(expected_metrics),\n        sorted_key_dict(eval_metrics),\n        rtol=1e-3)\n\n  def test_binary_classes_evaluation_for_simple_data(self):\n    self._test_evaluation_for_simple_data(n_classes=2)\n\n  def test_multi_classes_evaluation_for_simple_data(self):\n    self._test_evaluation_for_simple_data(n_classes=4)\n\n  def _test_evaluation_batch(self, n_classes):\n    \"\"\"Tests evaluation for batch_size==2.\"\"\"\n    label = [1, 0]\n    age = [17., 18.]\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    age_weight = [[2.0]] if n_classes == 2 else (np.reshape(\n        2.0 * np.array(list(range(n_classes)), dtype=np.float32),\n        (1, n_classes)))\n    bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(age_weight, name=AGE_WEIGHT_NAME)\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = self._linear_classifier_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n    eval_metrics = est.evaluate(\n        input_fn=lambda: ({\n            'age': (age)\n        }, (label)), steps=1)\n\n    if n_classes == 2:\n      # Logits are (-1., 1.) 
labels are (1, 0).\n      # Loss is\n      #   loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133\n      #   loss for row 2: (1 - 0) * -log(1 - sigmoid(1)) = 1.3133\n      expected_loss = (1.3133 * 2) / 2  # Divided by batch size\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: expected_loss,\n          metric_keys.MetricKeys.ACCURACY: 0.,\n          metric_keys.MetricKeys.PRECISION: 0.,\n          metric_keys.MetricKeys.RECALL: 0.,\n          metric_keys.MetricKeys.PREDICTION_MEAN: 0.5,\n          metric_keys.MetricKeys.LABEL_MEAN: 0.5,\n          metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,\n          metric_keys.MetricKeys.AUC: 0.,\n          metric_keys.MetricKeys.AUC_PR: 0.3068,\n      }\n    else:\n      # Multi classes: loss = 1 * -log ( soft_max(logits)[label] )\n      logits = age_weight * np.reshape(age, (2, 1)) + bias\n      logits_exp = np.exp(logits)\n      softmax_row_0 = logits_exp[0] / logits_exp[0].sum()\n      softmax_row_1 = logits_exp[1] / logits_exp[1].sum()\n      expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])\n      expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])\n      expected_loss = (expected_loss_0 + expected_loss_1) / 2  # batch size\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          metric_keys.MetricKeys.LOSS_MEAN: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.ACCURACY: 0.,\n      }\n\n    self.assertAllClose(\n        sorted_key_dict(expected_metrics),\n        sorted_key_dict(eval_metrics),\n        rtol=1e-3)\n\n  def test_binary_classes_evaluation_batch(self):\n    self._test_evaluation_batch(n_classes=2)\n\n  def test_multi_classes_evaluation_batch(self):\n    self._test_evaluation_batch(n_classes=4)\n\n  def _test_evaluation_weights(self, n_classes):\n    \"\"\"Tests evaluation with 
weights.\"\"\"\n\n    label = [1, 0]\n    age = [17., 18.]\n    weights = [1., 2.]\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    age_weight = [[2.0]] if n_classes == 2 else (np.reshape(\n        2.0 * np.array(list(range(n_classes)), dtype=np.float32),\n        (1, n_classes)))\n    bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(age_weight, name=AGE_WEIGHT_NAME)\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = self._linear_classifier_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        weight_column='w',\n        model_dir=self._model_dir)\n    eval_metrics = est.evaluate(\n        input_fn=lambda: ({\n            'age': (age),\n            'w': (weights)\n        }, (label)), steps=1)\n\n    if n_classes == 2:\n      # Logits are (-1., 1.) labels are (1, 0).\n      # Loss is\n      #   loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133\n      #   loss for row 2: (1 - 0) * -log(1 - sigmoid(1)) = 1.3133\n      #   weights = [1., 2.]\n      expected_loss = (1.3133 * (1. + 2.)) / 2  # Divided by batch size\n      loss_mean = (1.3133 * (1. 
+ 2.)) / (1.0 + 2.0)\n      label_mean = np.average(label, weights=weights)\n      logits = [-1, 1]\n      logistics = sigmoid(np.array(logits))\n      predictions_mean = np.average(logistics, weights=weights)\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: loss_mean,\n          metric_keys.MetricKeys.ACCURACY: 0.,\n          metric_keys.MetricKeys.PRECISION: 0.,\n          metric_keys.MetricKeys.RECALL: 0.,\n          metric_keys.MetricKeys.PREDICTION_MEAN: predictions_mean,\n          metric_keys.MetricKeys.LABEL_MEAN: label_mean,\n          metric_keys.MetricKeys.ACCURACY_BASELINE:\n              (max(label_mean, 1 - label_mean)),\n          metric_keys.MetricKeys.AUC: 0.,\n          metric_keys.MetricKeys.AUC_PR: 0.1891,\n      }\n    else:\n      # Multi classes: unweighted_loss = 1 * -log ( soft_max(logits)[label] )\n      logits = age_weight * np.reshape(age, (2, 1)) + bias\n      logits_exp = np.exp(logits)\n      softmax_row_0 = logits_exp[0] / logits_exp[0].sum()\n      softmax_row_1 = logits_exp[1] / logits_exp[1].sum()\n      expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])\n      expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])\n      loss_mean = np.average([expected_loss_0, expected_loss_1],\n                             weights=weights)\n      expected_loss = (loss_mean * np.sum(weights)) / 2  # batch size\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          metric_keys.MetricKeys.LOSS_MEAN: loss_mean,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.ACCURACY: 0.,\n      }\n\n    self.assertAllClose(\n        sorted_key_dict(expected_metrics),\n        sorted_key_dict(eval_metrics),\n        rtol=1e-3)\n\n  def test_binary_classes_evaluation_weights(self):\n    self._test_evaluation_weights(n_classes=2)\n\n  def 
test_multi_classes_evaluation_weights(self):\n    self._test_evaluation_weights(n_classes=4)\n\n\nclass BaseLinearClassifierPredictTest(object):\n\n  def __init__(self, linear_classifier_fn, fc_lib=feature_column_v2):\n    self._linear_classifier_fn = linear_classifier_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _testPredictions(self, n_classes, label_vocabulary, label_output_fn):\n    \"\"\"Tests predict when all variables are one-dimensional.\"\"\"\n    age = 1.\n\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    age_weight = [[-11.0]] if n_classes == 2 else (np.reshape(\n        -11.0 * np.array(list(range(n_classes)), dtype=np.float32),\n        (1, n_classes)))\n    bias = [10.0] if n_classes == 2 else [10.0] * n_classes\n\n    with tf.Graph().as_default():\n      tf.Variable(age_weight, name=AGE_WEIGHT_NAME)\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = self._linear_classifier_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        label_vocabulary=label_vocabulary,\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'age': np.array([[age]])},\n        y=None,\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    predictions = list(est.predict(input_fn=predict_input_fn))\n\n    if n_classes == 2:\n      scalar_logits = np.reshape(np.array(age_weight) * age + bias, (1,)).item()\n      two_classes_logits = [0, scalar_logits]\n      two_classes_logits_exp = np.exp(two_classes_logits)\n      softmax = two_classes_logits_exp / 
two_classes_logits_exp.sum()\n\n      expected_predictions = {\n          'class_ids': [0],\n          'all_class_ids': [0, 1],\n          'classes': [label_output_fn(0)],\n          'all_classes': [label_output_fn(0),\n                          label_output_fn(1)],\n          'logistic': [sigmoid(np.array(scalar_logits))],\n          'logits': [scalar_logits],\n          'probabilities': softmax,\n      }\n    else:\n      onedim_logits = np.reshape(np.array(age_weight) * age + bias, (-1,))\n      class_ids = onedim_logits.argmax()\n      all_class_ids = list(range(len(onedim_logits)))\n      logits_exp = np.exp(onedim_logits)\n      softmax = logits_exp / logits_exp.sum()\n      expected_predictions = {\n          'class_ids': [class_ids],\n          'all_class_ids': all_class_ids,\n          'classes': [label_output_fn(class_ids)],\n          'all_classes': [label_output_fn(i) for i in all_class_ids],\n          'logits': onedim_logits,\n          'probabilities': softmax,\n      }\n\n    self.assertEqual(1, len(predictions))\n    # assertAllClose cannot handle byte type.\n    self.assertEqual(expected_predictions['classes'], predictions[0]['classes'])\n    expected_predictions.pop('classes')\n    predictions[0].pop('classes')\n    self.assertAllEqual(expected_predictions['all_classes'],\n                        predictions[0]['all_classes'])\n    expected_predictions.pop('all_classes')\n    predictions[0].pop('all_classes')\n    self.assertAllClose(\n        sorted_key_dict(expected_predictions), sorted_key_dict(predictions[0]))\n\n  def testBinaryClassesWithoutLabelVocabulary(self):\n    n_classes = 2\n    self._testPredictions(\n        n_classes,\n        label_vocabulary=None,\n        label_output_fn=lambda x: ('%s' % x).encode())\n\n  def testBinaryClassesWithLabelVocabulary(self):\n    n_classes = 2\n    self._testPredictions(\n        n_classes,\n        label_vocabulary=['class_vocab_{}'.format(i) for i in range(n_classes)],\n        
label_output_fn=lambda x: ('class_vocab_%s' % x).encode())\n\n  def testMultiClassesWithoutLabelVocabulary(self):\n    n_classes = 4\n    self._testPredictions(\n        n_classes,\n        label_vocabulary=None,\n        label_output_fn=lambda x: ('%s' % x).encode())\n\n  def testMultiClassesWithLabelVocabulary(self):\n    n_classes = 4\n    self._testPredictions(\n        n_classes,\n        label_vocabulary=['class_vocab_{}'.format(i) for i in range(n_classes)],\n        label_output_fn=lambda x: ('class_vocab_%s' % x).encode())\n\n  def testSparseCombiner(self):\n    w_a = 2.0\n    w_b = 3.0\n    w_c = 5.0\n    bias = 5.0\n    with tf.Graph().as_default():\n      tf.Variable([[w_a], [w_b], [w_c]], name=LANGUAGE_WEIGHT_NAME)\n      tf.Variable([bias], name=BIAS_NAME)\n      tf.Variable(\n          1, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    def _input_fn():\n      return tf.compat.v1.data.Dataset.from_tensors({\n          'language':\n              tf.sparse.SparseTensor(\n                  values=['a', 'c', 'b', 'c'],\n                  indices=[[0, 0], [0, 1], [1, 0], [1, 1]],\n                  dense_shape=[2, 2]),\n      })\n\n    feature_columns = (self._fc_lib.categorical_column_with_vocabulary_list(\n        'language', vocabulary_list=['a', 'b', 'c']),)\n\n    # Check prediction for each sparse_combiner.\n    # With sparse_combiner = 'sum', we have\n    # logits_1 = w_a + w_c + bias\n    #          = 2.0 + 5.0 + 5.0 = 12.0\n    # logits_2 = w_b + w_c + bias\n    #          = 3.0 + 5.0 + 5.0 = 13.0\n    linear_classifier = self._linear_classifier_fn(\n        feature_columns=feature_columns, model_dir=self._model_dir)\n    predictions = linear_classifier.predict(input_fn=_input_fn)\n    predicted_scores = list([x['logits'] for x in predictions])\n    self.assertAllClose([[12.0], [13.0]], predicted_scores)\n\n    # With sparse_combiner = 'mean', we have\n    # logits_1 = 1/2 * 
(w_a + w_c) + bias\n    #          = 1/2 * (2.0 + 5.0) + 5.0 = 8.5\n    # logits_2 = 1/2 * (w_b + w_c) + bias\n    #          = 1/2 * (3.0 + 5.0) + 5.0 = 9.0\n    linear_classifier = self._linear_classifier_fn(\n        feature_columns=feature_columns,\n        model_dir=self._model_dir,\n        sparse_combiner='mean')\n    predictions = linear_classifier.predict(input_fn=_input_fn)\n    predicted_scores = list([x['logits'] for x in predictions])\n    self.assertAllClose([[8.5], [9.0]], predicted_scores)\n\n    # With sparse_combiner = 'sqrtn', we have\n    # logits_1 = sqrt(2)/2 * (w_a + w_c) + bias\n    #          = sqrt(2)/2 * (2.0 + 5.0) + 5.0 = 9.94974\n    # logits_2 = sqrt(2)/2 * (w_b + w_c) + bias\n    #          = sqrt(2)/2 * (3.0 + 5.0) + 5.0 = 10.65685\n    linear_classifier = self._linear_classifier_fn(\n        feature_columns=feature_columns,\n        model_dir=self._model_dir,\n        sparse_combiner='sqrtn')\n    predictions = linear_classifier.predict(input_fn=_input_fn)\n    predicted_scores = list([x['logits'] for x in predictions])\n    self.assertAllClose([[9.94974], [10.65685]], predicted_scores)\n\n\nclass BaseLinearClassifierIntegrationTest(object):\n\n  def __init__(self, linear_classifier_fn, fc_lib=feature_column_v2):\n    self._linear_classifier_fn = linear_classifier_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self, n_classes, train_input_fn, eval_input_fn,\n                          predict_input_fn, input_dimension, prediction_length):\n    feature_columns = [\n        self._fc_lib.numeric_column('x', shape=(input_dimension,))\n    ]\n    est = self._linear_classifier_fn(\n        feature_columns=feature_columns,\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    # TRAIN\n    # learn y = x\n    est.train(train_input_fn, steps=200)\n\n    # 
EVALUATE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(200, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores))\n\n    # PREDICT\n    predictions = np.array(\n        [x['classes'] for x in est.predict(predict_input_fn)])\n    self.assertAllEqual((prediction_length, 1), predictions.shape)\n\n    # EXPORT\n    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def _test_numpy_input_fn(self, n_classes):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    input_dimension = 4\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, input_dimension)\n    target = np.array([1] * batch_size)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=target,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=target,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=None,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n\n    self._test_complete_flow(\n        n_classes=n_classes,\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        prediction_length=prediction_length)\n\n  def test_binary_classes_numpy_input_fn(self):\n    
self._test_numpy_input_fn(n_classes=2)\n\n  def test_multi_classes_numpy_input_fn(self):\n    self._test_numpy_input_fn(n_classes=4)\n\n  def _test_pandas_input_fn(self, n_classes):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n\n    # Pandas DataFrame naturally supports 1 dim data only.\n    input_dimension = 1\n    batch_size = 10\n    data = np.array([1., 2., 3., 4.], dtype=np.float32)\n    target = np.array([1, 0, 1, 0], dtype=np.int32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(target)\n    prediction_length = 4\n\n    train_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        n_classes=n_classes,\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        prediction_length=prediction_length)\n\n  def test_binary_classes_pandas_input_fn(self):\n    self._test_pandas_input_fn(n_classes=2)\n\n  def test_multi_classes_pandas_input_fn(self):\n    self._test_pandas_input_fn(n_classes=4)\n\n  def _test_input_fn_from_parse_example(self, n_classes):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    input_dimension = 2\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, input_dimension)\n    target = np.array([1] * batch_size, dtype=np.int64)\n\n    serialized_examples = []\n    for x, y in zip(data, target):\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n  
                    feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=x)),\n                  'y':\n                      feature_pb2.Feature(\n                          int64_list=feature_pb2.Int64List(value=[y])),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([input_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([1], tf.dtypes.int64),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        n_classes=n_classes,\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=input_dimension,\n        prediction_length=prediction_length)\n\n  def test_binary_classes_input_fn_from_parse_example(self):\n    self._test_input_fn_from_parse_example(n_classes=2)\n\n  def test_multi_classes_input_fn_from_parse_example(self):\n    self._test_input_fn_from_parse_example(n_classes=4)\n\n\nclass BaseLinearLogitFnTest(object):\n\n  def __init__(self, fc_lib=feature_column_v2):\n    
self._fc_lib = fc_lib\n\n  def test_basic_logit_correctness(self):\n    \"\"\"linear_logit_fn simply wraps feature_column_lib.linear_model.\"\"\"\n    age = self._fc_lib.numeric_column('age')\n    with tf.Graph().as_default():\n      logit_fn = linear.linear_logit_fn_builder(units=2, feature_columns=[age])\n      logits = logit_fn(features={'age': [[23.], [31.]]})\n      bias_var = tf.compat.v1.get_collection(\n          tf.compat.v1.GraphKeys.GLOBAL_VARIABLES,\n          'linear_model/bias_weights')[0]\n      age_var = tf.compat.v1.get_collection(\n          tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, 'linear_model/age')[0]\n      with tf.compat.v1.Session() as sess:\n        sess.run([tf.compat.v1.initializers.global_variables()])\n        self.assertAllClose([[0., 0.], [0., 0.]], logits.eval())\n        sess.run(bias_var.assign([10., 5.]))\n        self.assertAllClose([[10., 5.], [10., 5.]], logits.eval())\n        sess.run(age_var.assign([[2.0, 3.0]]))\n        # [2 * 23 + 10, 3 * 23 + 5] = [56, 74].\n        # [2 * 31 + 10, 3 * 31 + 5] = [72, 98]\n        self.assertAllClose([[56., 74.], [72., 98.]], logits.eval())\n\n  def test_compute_fraction_of_zero_v2(self):\n    \"\"\"Tests the calculation of sparsity.\"\"\"\n    if self._fc_lib != feature_column_v2:\n      return\n\n    age = tf.feature_column.numeric_column('age')\n    occupation = tf.feature_column.categorical_column_with_hash_bucket(\n        'occupation', hash_bucket_size=5)\n    with tf.Graph().as_default():\n      model = linear.LinearModel(\n          feature_columns=[age, occupation], units=3, name='linear_model')\n      features = {\n          'age': [[23.], [31.]],\n          'occupation': [['doctor'], ['engineer']]\n      }\n      model(features)\n      variables = model.variables\n      variables.remove(model.bias)\n      fraction_zero = linear._compute_fraction_of_zero(variables)\n      age_var = tf.compat.v1.get_collection(\n          tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, 
'linear_model/age')[0]\n      with tf.compat.v1.Session() as sess:\n        sess.run([tf.compat.v1.initializers.global_variables()])\n        # Upon initialization, all variables will be zero.\n        self.assertAllClose(1, fraction_zero.eval())\n\n        sess.run(age_var.assign([[2.0, 0.0, -1.0]]))\n        # 1 of the 3 age weights are zero, and all of the 15 (5 hash buckets\n        # x 3-dim output) are zero.\n        self.assertAllClose(16. / 18., fraction_zero.eval())\n\n\nclass BaseLinearWarmStartingTest(object):\n\n  def __init__(self,\n               _linear_classifier_fn,\n               _linear_regressor_fn,\n               fc_lib=feature_column_v2):\n    self._linear_classifier_fn = _linear_classifier_fn\n    self._linear_regressor_fn = _linear_regressor_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    # Create a directory to save our old checkpoint and vocabularies to.\n    self._ckpt_and_vocab_dir = tempfile.mkdtemp()\n\n    # Make a dummy input_fn.\n    def _input_fn():\n      features = {\n          'age': [[23.], [31.]],\n          'age_in_years': [[23.], [31.]],\n          'occupation': [['doctor'], ['consultant']]\n      }\n      return features, [0, 1]\n\n    self._input_fn = _input_fn\n\n  def tearDown(self):\n    # Clean up checkpoint / vocab dir.\n    tf.compat.v1.summary.FileWriterCache.clear()\n    shutil.rmtree(self._ckpt_and_vocab_dir)\n\n  def test_classifier_basic_warm_starting(self):\n    \"\"\"Tests correctness of LinearClassifier default warm-start.\"\"\"\n    age = self._fc_lib.numeric_column('age')\n\n    # Create a LinearClassifier and train to save a checkpoint.\n    linear_classifier = self._linear_classifier_fn(\n        feature_columns=[age],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD')\n    linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second LinearClassifier, warm-started from the first.  
Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    warm_started_linear_classifier = self._linear_classifier_fn(\n        feature_columns=[age],\n        n_classes=4,\n        optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n        warm_start_from=linear_classifier.model_dir)\n\n    warm_started_linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_linear_classifier.get_variable_names():\n      # Learning rate is also checkpointed in V2 optimizer. So we need to make\n      # sure it uses the new value after warm started.\n      if 'learning_rate' in variable_name:\n        self.assertAllClose(\n            0.0,\n            warm_started_linear_classifier.get_variable_value(variable_name))\n      else:\n        self.assertAllClose(\n            linear_classifier.get_variable_value(variable_name),\n            warm_started_linear_classifier.get_variable_value(variable_name))\n\n  def test_regressor_basic_warm_starting(self):\n    \"\"\"Tests correctness of LinearRegressor default warm-start.\"\"\"\n    age = self._fc_lib.numeric_column('age')\n\n    # Create a LinearRegressor and train to save a checkpoint.\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=[age],\n        model_dir=self._ckpt_and_vocab_dir,\n        optimizer='SGD')\n    linear_regressor.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second LinearRegressor, warm-started from the first.  
Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    warm_started_linear_regressor = self._linear_regressor_fn(\n        feature_columns=[age],\n        optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n        warm_start_from=linear_regressor.model_dir)\n\n    warm_started_linear_regressor.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_linear_regressor.get_variable_names():\n      # Learning rate is also checkpointed in V2 optimizer. So we need to make\n      # sure it uses the new value after warm started.\n      if 'learning_rate' in variable_name:\n        self.assertAllClose(\n            0.0,\n            warm_started_linear_regressor.get_variable_value(variable_name))\n      else:\n        self.assertAllClose(\n            linear_regressor.get_variable_value(variable_name),\n            warm_started_linear_regressor.get_variable_value(variable_name))\n\n  def test_warm_starting_selective_variables(self):\n    \"\"\"Tests selecting variables to warm-start.\"\"\"\n    age = self._fc_lib.numeric_column('age')\n\n    # Create a LinearClassifier and train to save a checkpoint.\n    linear_classifier = self._linear_classifier_fn(\n        feature_columns=[age],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD')\n    linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second LinearClassifier, warm-started from the first.  
Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    warm_started_linear_classifier = self._linear_classifier_fn(\n        feature_columns=[age],\n        n_classes=4,\n        optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n        # The provided regular expression will only warm-start the age variable\n        # and not the bias.\n        warm_start_from=estimator.WarmStartSettings(\n            ckpt_to_initialize_from=linear_classifier.model_dir,\n            vars_to_warm_start='.*(age).*'))\n\n    warm_started_linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n    self.assertAllClose(\n        linear_classifier.get_variable_value(AGE_WEIGHT_NAME),\n        warm_started_linear_classifier.get_variable_value(AGE_WEIGHT_NAME))\n    # Bias should still be zero from initialization.\n    self.assertAllClose(\n        [0.0] * 4, warm_started_linear_classifier.get_variable_value(BIAS_NAME))\n\n  def test_warm_starting_with_vocab_remapping_and_partitioning(self):\n    \"\"\"Tests warm-starting with vocab remapping and partitioning.\"\"\"\n    vocab_list = ['doctor', 'lawyer', 'consultant']\n    vocab_file = os.path.join(self._ckpt_and_vocab_dir, 'occupation_vocab')\n    with open(vocab_file, 'w') as f:\n      f.write('\\n'.join(vocab_list))\n    occupation = self._fc_lib.categorical_column_with_vocabulary_file(\n        'occupation',\n        vocabulary_file=vocab_file,\n        vocabulary_size=len(vocab_list))\n\n    # Create a LinearClassifier and train to save a checkpoint.\n    linear_classifier = self._linear_classifier_fn(\n        feature_columns=[occupation],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD')\n    linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second LinearClassifier, warm-started from the first.  
Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).  Use a new FeatureColumn with a\n    # different vocabulary for occupation.\n    new_vocab_list = ['doctor', 'consultant', 'engineer']\n    new_vocab_file = os.path.join(self._ckpt_and_vocab_dir,\n                                  'new_occupation_vocab')\n    with open(new_vocab_file, 'w') as f:\n      f.write('\\n'.join(new_vocab_list))\n    new_occupation = self._fc_lib.categorical_column_with_vocabulary_file(\n        'occupation',\n        vocabulary_file=new_vocab_file,\n        vocabulary_size=len(new_vocab_list))\n    # We can create our VocabInfo object from the new and old occupation\n    # FeatureColumn's.\n    occupation_vocab_info = estimator.VocabInfo(\n        new_vocab=new_occupation.vocabulary_file,\n        new_vocab_size=new_occupation.vocabulary_size,\n        num_oov_buckets=new_occupation.num_oov_buckets,\n        old_vocab=occupation.vocabulary_file,\n        old_vocab_size=occupation.vocabulary_size,\n        # Can't use constant_initializer with load_and_remap.  
In practice,\n        # use a truncated normal initializer.\n        backup_initializer=tf.compat.v1.initializers.random_uniform(\n            minval=0.39, maxval=0.39))\n    warm_started_linear_classifier = self._linear_classifier_fn(\n        feature_columns=[occupation],\n        n_classes=4,\n        optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n        warm_start_from=estimator.WarmStartSettings(\n            ckpt_to_initialize_from=linear_classifier.model_dir,\n            var_name_to_vocab_info={\n                OCCUPATION_WEIGHT_NAME: occupation_vocab_info\n            },\n            # Explicitly providing None here will only warm-start variables\n            # referenced in var_name_to_vocab_info (the bias will not be\n            # warm-started).\n            vars_to_warm_start=None))\n\n    warm_started_linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n    # 'doctor' was ID-0 and still ID-0.\n    self.assertAllClose(\n        linear_classifier.get_variable_value(OCCUPATION_WEIGHT_NAME)[0, :],\n        warm_started_linear_classifier.get_variable_value(\n            OCCUPATION_WEIGHT_NAME)[0, :])\n    # 'consultant' was ID-2 and now ID-1.\n    self.assertAllClose(\n        linear_classifier.get_variable_value(OCCUPATION_WEIGHT_NAME)[2, :],\n        warm_started_linear_classifier.get_variable_value(\n            OCCUPATION_WEIGHT_NAME)[1, :])\n    # 'engineer' is a new entry and should be initialized with the\n    # backup_initializer in VocabInfo.\n    self.assertAllClose([0.39] * 4,\n                        warm_started_linear_classifier.get_variable_value(\n                            OCCUPATION_WEIGHT_NAME)[2, :])\n    # Bias should still be zero (from initialization logic).\n    self.assertAllClose(\n        [0.0] * 4, warm_started_linear_classifier.get_variable_value(BIAS_NAME))\n\n  def test_warm_starting_with_naming_change(self):\n    \"\"\"Tests warm-starting with a Tensor name remapping.\"\"\"\n    age_in_years = 
self._fc_lib.numeric_column('age_in_years')\n\n    # Create a LinearClassifier and train to save a checkpoint.\n    linear_classifier = self._linear_classifier_fn(\n        feature_columns=[age_in_years],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD')\n    linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second LinearClassifier, warm-started from the first.  Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    warm_started_linear_classifier = self._linear_classifier_fn(\n        feature_columns=[self._fc_lib.numeric_column('age')],\n        n_classes=4,\n        optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.0),\n        # The 'age' variable corresponds to the 'age_in_years' variable in the\n        # previous model.\n        warm_start_from=estimator.WarmStartSettings(\n            ckpt_to_initialize_from=linear_classifier.model_dir,\n            var_name_to_prev_var_name={\n                AGE_WEIGHT_NAME: AGE_WEIGHT_NAME.replace('age', 'age_in_years')\n            }))\n\n    warm_started_linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n    self.assertAllClose(\n        linear_classifier.get_variable_value(\n            AGE_WEIGHT_NAME.replace('age', 'age_in_years')),\n        warm_started_linear_classifier.get_variable_value(AGE_WEIGHT_NAME))\n    # The bias is also warm-started (with no name remapping).\n    self.assertAllClose(\n        linear_classifier.get_variable_value(BIAS_NAME),\n        warm_started_linear_classifier.get_variable_value(BIAS_NAME))\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/metric_keys.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Enum for model prediction keys.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom tensorflow_estimator.python.estimator import model_fn\n\n\nclass MetricKeys(object):\n  \"\"\"Metric key strings.\"\"\"\n  LOSS = model_fn.LOSS_METRIC_KEY\n  LOSS_MEAN = model_fn.AVERAGE_LOSS_METRIC_KEY\n  LOSS_REGULARIZATION = 'regularization_loss'\n\n  ACCURACY = 'accuracy'\n  PRECISION = 'precision'\n  RECALL = 'recall'\n  # This is the best the model could do by always predicting one class.\n  # Should be < ACCURACY in a trained model.\n  ACCURACY_BASELINE = 'accuracy_baseline'\n  AUC = 'auc'\n  AUC_PR = 'auc_precision_recall'\n  LABEL_MEAN = 'label/mean'\n  PREDICTION_MEAN = 'prediction/mean'\n\n  # The following require a threshold applied, should be float in range (0, 1).\n  ACCURACY_AT_THRESHOLD = 'accuracy/positive_threshold_%g'\n  PRECISION_AT_THRESHOLD = 'precision/positive_threshold_%g'\n  RECALL_AT_THRESHOLD = 'recall/positive_threshold_%g'\n\n  # The following require a constraint on a competing metric to be applied,\n  # float in range (0, 1).\n  PRECISION_AT_RECALL = 'precision_at_recall_%g'\n  RECALL_AT_PRECISION = 'recall_at_precision_%g'\n  SENSITIVITY_AT_SPECIFICITY = 
'sensitivity_at_specificity_%g'\n  SPECIFICITY_AT_SENSITIVITY = 'specificity_at_sensitivity_%g'\n\n  # The following require a class id applied.\n  PROBABILITY_MEAN_AT_CLASS = 'probability_mean/class%d'\n  AUC_AT_CLASS = 'auc/class%d'\n  AUC_PR_AT_CLASS = 'auc_precision_recall/class%d'\n\n  # The following require a class name applied.\n  PROBABILITY_MEAN_AT_NAME = 'probability_mean/%s'\n  AUC_AT_NAME = 'auc/%s'\n  AUC_PR_AT_NAME = 'auc_precision_recall/%s'\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/optimizers.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Methods related to optimizers used in canned_estimators.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport inspect\nfrom absl import logging\nimport six\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.util import tf_keras\n\n_OPTIMIZER_CLS_NAMES = {\n    'Adagrad': tf.compat.v1.train.AdagradOptimizer,\n    'Adam': tf.compat.v1.train.AdamOptimizer,\n    'Ftrl': tf.compat.v1.train.FtrlOptimizer,\n    'RMSProp': tf.compat.v1.train.RMSPropOptimizer,\n    'SGD': tf.compat.v1.train.GradientDescentOptimizer,\n}\n\n_OPTIMIZER_CLS_NAMES_V2 = {\n    'Adagrad': tf_keras.optimizers.legacy.Adagrad,\n    'Adam': tf_keras.optimizers.legacy.Adam,\n    'Ftrl': tf_keras.optimizers.legacy.Ftrl,\n    'RMSProp': tf_keras.optimizers.legacy.RMSprop,\n    'SGD': tf_keras.optimizers.legacy.SGD,\n}\n\n# The default learning rate of 0.05 is a historical artifact of the initial\n# implementation, but seems a reasonable choice.\n_LEARNING_RATE = 0.05\n\n\ndef get_optimizer_instance(opt, learning_rate=None):\n  \"\"\"Returns an optimizer instance.\n\n  Supports the following types for the given `opt`:\n  * An `Optimizer` instance: Returns the given `opt`.\n  * A string: Creates an 
`Optimizer` subclass with the given `learning_rate`.\n    Supported strings:\n    * 'Adagrad': Returns an `AdagradOptimizer`.\n    * 'Adam': Returns an `AdamOptimizer`.\n    * 'Ftrl': Returns an `FtrlOptimizer`.\n    * 'RMSProp': Returns an `RMSPropOptimizer`.\n    * 'SGD': Returns a `GradientDescentOptimizer`.\n\n  Args:\n    opt: An `Optimizer` instance, or string, as discussed above.\n    learning_rate: A float. Only used if `opt` is a string.\n\n  Returns:\n    An `Optimizer` instance.\n\n  Raises:\n    ValueError: If `opt` is an unsupported string.\n    ValueError: If `opt` is a supported string but `learning_rate` was not\n      specified.\n    ValueError: If `opt` is none of the above types.\n  \"\"\"\n  if isinstance(opt, six.string_types):\n    if opt in six.iterkeys(_OPTIMIZER_CLS_NAMES):\n      if not learning_rate:\n        raise ValueError('learning_rate must be specified when opt is string.')\n      return _OPTIMIZER_CLS_NAMES[opt](learning_rate=learning_rate)\n    raise ValueError(\n        'Unsupported optimizer name: {}. Supported names are: {}'.format(\n            opt, tuple(sorted(six.iterkeys(_OPTIMIZER_CLS_NAMES)))))\n  if callable(opt):\n    opt = opt()\n  if not isinstance(opt, tf.compat.v1.train.Optimizer):\n    raise ValueError(\n        'The given object is not an Optimizer instance. 
Given: {}'.format(opt))\n  return opt\n\n\ndef _optimizer_has_default_learning_rate(opt):\n  signature = inspect.getfullargspec(opt.__init__)\n  default_name_to_value = dict(zip(signature.args[::-1], signature.defaults))\n  for name in signature.kwonlyargs:\n    if name in signature.kwonlydefaults:\n      default_name_to_value[name] = signature.kwonlydefaults[name]\n  return 'learning_rate' in default_name_to_value\n\n\ndef get_optimizer_instance_v2(opt, learning_rate=None):\n  \"\"\"Returns an optimizer_v2.OptimizerV2 instance.\n\n  Supports the following types for the given `opt`:\n  * An `optimizer_v2.OptimizerV2` instance: Returns the given `opt`.\n  * A string: Creates an `optimizer_v2.OptimizerV2` subclass with the given\n  `learning_rate`.\n    Supported strings:\n    * 'Adagrad': Returns an tf_keras.optimizers.Adagrad.\n    * 'Adam': Returns an tf_keras.optimizers.Adam.\n    * 'Ftrl': Returns an tf_keras.optimizers.Ftrl.\n    * 'RMSProp': Returns an tf_keras.optimizers.RMSProp.\n    * 'SGD': Returns a tf_keras.optimizers.SGD.\n\n  Args:\n    opt: An `tf_keras.optimizers.Optimizer` instance, or string, as discussed\n      above.\n    learning_rate: A float. Only used if `opt` is a string. 
If None, and opt is\n      string, it will use the default learning_rate of the optimizer.\n\n  Returns:\n    A `tf_keras.optimizers.Optimizer` instance.\n\n  Raises:\n    ValueError: If `opt` is an unsupported string.\n    ValueError: If `opt` is none of the above types.\n  \"\"\"\n  if isinstance(opt, six.string_types):\n    if opt in six.iterkeys(_OPTIMIZER_CLS_NAMES_V2):\n      if not learning_rate:\n        if _optimizer_has_default_learning_rate(_OPTIMIZER_CLS_NAMES_V2[opt]):\n          return _OPTIMIZER_CLS_NAMES_V2[opt]()\n        else:\n          return _OPTIMIZER_CLS_NAMES_V2[opt](learning_rate=_LEARNING_RATE)\n      return _OPTIMIZER_CLS_NAMES_V2[opt](learning_rate=learning_rate)\n    raise ValueError(\n        'Unsupported optimizer name: {}. Supported names are: {}'.format(\n            opt, tuple(sorted(six.iterkeys(_OPTIMIZER_CLS_NAMES_V2)))))\n  if callable(opt):\n    opt = opt()\n  if isinstance(opt, tf_keras.optimizers.experimental.Optimizer):\n    if tf.executing_eagerly():\n      logging.warning(\n          'You are using `tf_keras.optimizers.experimental.Optimizer` in TF '\n          'estimator, which only supports '\n          '`tf_keras.optimizers.legacy.Optimizer`. Automatically converting '\n          'your optimizer to `tf_keras.optimizers.legacy.Optimizer`.')\n      opt = tf_keras.__internal__.optimizers.convert_to_legacy_optimizer(opt)\n    else:\n      raise ValueError('Please set your optimizer as an instance of '\n                       '`tf_keras.optimizers.legacy.Optimizer`, e.g., '\n                       f'`tf_keras.optimizers.legacy.{opt.__class__.__name__}`.'\n                       f'Received optimizer type: {type(opt)}.')\n  if not isinstance(\n      opt,\n      (tf_keras.optimizers.legacy.Optimizer, tf_keras.optimizers.Optimizer)):\n    raise ValueError(\n        'The given object is not a tf_keras.optimizers.Optimizer instance.'\n        ' Given: {}'.format(opt))\n  return opt\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/optimizers_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for optimizers.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import optimizers\n\n\nclass _TestOptimizer(tf.compat.v1.train.Optimizer):\n\n  def __init__(self):\n    super(_TestOptimizer, self).__init__(\n        use_locking=False, name='TestOptimizer')\n\n\nclass GetOptimizerInstance(tf.test.TestCase):\n\n  def test_unsupported_name(self):\n    with self.assertRaisesRegex(\n        ValueError, 'Unsupported optimizer name: unsupported_name'):\n      optimizers.get_optimizer_instance('unsupported_name', learning_rate=0.1)\n\n  def test_supported_name_but_learning_rate_none(self):\n    with self.assertRaisesRegex(\n        ValueError, 'learning_rate must be specified when opt is string'):\n      optimizers.get_optimizer_instance('Adagrad', learning_rate=None)\n\n  def test_keras_optimizer_after_tf_2_11(self):\n    new_opt = tf_keras.optimizers.Adagrad()\n\n    # In eager mode it should automatically convert to legacy optimizer.\n    opt = optimizers.get_optimizer_instance_v2(new_opt, learning_rate=0.1)\n    
self.assertIsInstance(opt, tf_keras.optimizers.legacy.Adagrad)\n\n    # In graph mode errors should be thrown.\n    @tf.function\n    def foo():\n      with self.assertRaisesRegex(\n          ValueError,\n          r'Please set your.*tf_keras\\.optimizers\\.legacy\\.Adagrad.*'):\n        optimizers.get_optimizer_instance_v2(new_opt, learning_rate=0.1)\n    foo()\n\n  def test_adagrad(self):\n    opt = optimizers.get_optimizer_instance('Adagrad', learning_rate=0.1)\n    self.assertIsInstance(opt, tf.compat.v1.train.AdagradOptimizer)\n    self.assertAlmostEqual(0.1, opt._learning_rate)\n\n  def test_adam(self):\n    opt = optimizers.get_optimizer_instance('Adam', learning_rate=0.1)\n    self.assertIsInstance(opt, tf.compat.v1.train.AdamOptimizer)\n    self.assertAlmostEqual(0.1, opt._lr)\n\n  def test_ftrl(self):\n    opt = optimizers.get_optimizer_instance('Ftrl', learning_rate=0.1)\n    self.assertIsInstance(opt, tf.compat.v1.train.FtrlOptimizer)\n    self.assertAlmostEqual(0.1, opt._learning_rate)\n\n  def test_rmsprop(self):\n    opt = optimizers.get_optimizer_instance('RMSProp', learning_rate=0.1)\n    self.assertIsInstance(opt, tf.compat.v1.train.RMSPropOptimizer)\n    self.assertAlmostEqual(0.1, opt._learning_rate)\n\n  def test_sgd(self):\n    opt = optimizers.get_optimizer_instance('SGD', learning_rate=0.1)\n    self.assertIsInstance(opt, tf.compat.v1.train.GradientDescentOptimizer)\n    self.assertAlmostEqual(0.1, opt._learning_rate)\n\n  def test_object(self):\n    opt = optimizers.get_optimizer_instance(_TestOptimizer())\n    self.assertIsInstance(opt, _TestOptimizer)\n\n  def test_object_invalid(self):\n    with self.assertRaisesRegex(\n        ValueError, 'The given object is not an Optimizer instance'):\n      optimizers.get_optimizer_instance((1, 2, 3))\n\n  def test_callable(self):\n\n    def _optimizer_fn():\n      return _TestOptimizer()\n\n    opt = optimizers.get_optimizer_instance(_optimizer_fn)\n    self.assertIsInstance(opt, 
_TestOptimizer)\n\n  def test_lambda(self):\n    opt = optimizers.get_optimizer_instance(lambda: _TestOptimizer())  # pylint: disable=unnecessary-lambda\n    self.assertIsInstance(opt, _TestOptimizer)\n\n  def test_callable_returns_invalid(self):\n\n    def _optimizer_fn():\n      return (1, 2, 3)\n\n    with self.assertRaisesRegex(\n        ValueError, 'The given object is not an Optimizer instance'):\n      optimizers.get_optimizer_instance(_optimizer_fn)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/optimizers_test_v2.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for optimizers.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import optimizers\n\n\nclass _TestOptimizerV2(tf_keras.optimizers.legacy.Optimizer):\n\n  def __init__(self):\n    super(_TestOptimizerV2, self).__init__(name='TestOptimizer')\n\n  def get_config(self):\n    pass\n\n\nclass GetOptimizerInstanceV2(tf.test.TestCase):\n  \"\"\"Tests for Optimizer V2.\"\"\"\n\n  def test_unsupported_name(self):\n    with self.assertRaisesRegexp(\n        ValueError, 'Unsupported optimizer name: unsupported_name'):\n      optimizers.get_optimizer_instance_v2(\n          'unsupported_name', learning_rate=0.1)\n\n  def test_adagrad_but_no_learning_rate(self):\n    with self.cached_session():\n      opt = optimizers.get_optimizer_instance_v2('Adagrad')\n      # The creation of variables in optimizer_v2 is deferred to when it's\n      # called, so we need to manually create it here. 
Same for all other tests.\n      self.assertIsInstance(opt.learning_rate, tf.Variable)\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.assertIsInstance(\n          opt,\n          (tf_keras.optimizers.Adagrad, tf_keras.optimizers.legacy.Adagrad))\n      self.assertAlmostEqual(0.001, self.evaluate(opt.learning_rate))\n\n  def test_adam_but_no_learning_rate(self):\n    with self.cached_session():\n      opt = optimizers.get_optimizer_instance_v2('Adam')\n      self.assertIsInstance(opt.learning_rate, tf.Variable)\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.assertIsInstance(\n          opt, (tf_keras.optimizers.Adam, tf_keras.optimizers.legacy.Adam))\n      self.assertAlmostEqual(0.001, self.evaluate(opt.learning_rate))\n\n  def test_adagrad(self):\n    with self.cached_session():\n      opt = optimizers.get_optimizer_instance_v2('Adagrad', learning_rate=0.1)\n      self.assertIsInstance(opt.learning_rate, tf.Variable)\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.assertIsInstance(\n          opt,\n          (tf_keras.optimizers.Adagrad, tf_keras.optimizers.legacy.Adagrad))\n      self.assertAlmostEqual(0.1, self.evaluate(opt.learning_rate))\n\n  def test_adam(self):\n    with self.cached_session():\n      opt = optimizers.get_optimizer_instance_v2('Adam', learning_rate=0.1)\n      self.assertIsInstance(opt.learning_rate, tf.Variable)\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.assertIsInstance(\n          opt, (tf_keras.optimizers.Adam, tf_keras.optimizers.legacy.Adam))\n      self.assertAlmostEqual(0.1, self.evaluate(opt.learning_rate))\n\n  def test_ftrl(self):\n    with self.cached_session():\n      opt = optimizers.get_optimizer_instance_v2('Ftrl', learning_rate=0.1)\n      self.assertIsInstance(opt.learning_rate, tf.Variable)\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.assertIsInstance(\n         
 opt, (tf_keras.optimizers.Ftrl, tf_keras.optimizers.legacy.Ftrl))\n      self.assertAlmostEqual(0.1, self.evaluate(opt.learning_rate))\n\n  def test_rmsprop(self):\n    with self.cached_session():\n      opt = optimizers.get_optimizer_instance_v2('RMSProp', learning_rate=0.1)\n      self.assertIsInstance(opt.learning_rate, tf.Variable)\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.assertIsInstance(\n          opt,\n          (tf_keras.optimizers.RMSprop, tf_keras.optimizers.legacy.RMSprop))\n      self.assertAlmostEqual(0.1, self.evaluate(opt.learning_rate))\n\n  def test_sgd(self):\n    with self.cached_session():\n      opt = optimizers.get_optimizer_instance_v2('SGD', learning_rate=0.1)\n      self.assertIsInstance(opt.learning_rate, tf.Variable)\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.assertIsInstance(\n          opt, (tf_keras.optimizers.SGD, tf_keras.optimizers.legacy.SGD))\n      self.assertAlmostEqual(0.1, self.evaluate(opt.learning_rate))\n\n  def test_object(self):\n    opt = optimizers.get_optimizer_instance_v2(_TestOptimizerV2())\n    self.assertIsInstance(opt, _TestOptimizerV2)\n\n  def test_object_invalid(self):\n    with self.assertRaisesRegexp(\n        ValueError,\n        'The given object is not a tf_keras.optimizers.Optimizer instance'):\n      optimizers.get_optimizer_instance_v2((1, 2, 3))\n\n  def test_callable(self):\n\n    def _optimizer_fn():\n      return _TestOptimizerV2()\n\n    opt = optimizers.get_optimizer_instance_v2(_optimizer_fn)\n    self.assertIsInstance(opt, _TestOptimizerV2)\n\n  def test_lambda(self):\n    opt = optimizers.get_optimizer_instance_v2(lambda: _TestOptimizerV2())  # pylint: disable=unnecessary-lambda\n    self.assertIsInstance(opt, _TestOptimizerV2)\n\n  def test_callable_returns_invalid(self):\n\n    def _optimizer_fn():\n      return (1, 2, 3)\n\n    with self.assertRaisesRegexp(\n        ValueError,\n        'The given object is not a 
tf_keras.optimizers.Optimizer instance'):\n      optimizers.get_optimizer_instance_v2(_optimizer_fn)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/parsing_utils.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Parsing related helper function to be used in `input_fn`.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.feature_column import feature_column_lib as fc\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\n\n\n@estimator_export('estimator.classifier_parse_example_spec', v1=[])\ndef classifier_parse_example_spec_v2(feature_columns,\n                                     label_key,\n                                     label_dtype=tf.dtypes.int64,\n                                     label_default=None,\n                                     weight_column=None):\n  \"\"\"Generates parsing spec for tf.parse_example to be used with classifiers.\n\n  If users keep data in tf.Example format, they need to call tf.parse_example\n  with a proper feature spec. There are two main things that this utility helps:\n\n  * Users need to combine parsing spec of features with labels and weights\n    (if any) since they are all parsed from same tf.Example instance. 
This\n    utility combines these specs.\n  * It is difficult to map expected label by a classifier such as\n    `DNNClassifier` to corresponding tf.parse_example spec. This utility encodes\n    it by getting related information from users (key, dtype).\n\n  Example output of parsing spec:\n\n  ```python\n  # Define features and transformations\n  feature_b = tf.feature_column.numeric_column(...)\n  feature_c_bucketized = tf.feature_column.bucketized_column(\n    tf.feature_column.numeric_column(\"feature_c\"), ...)\n  feature_a_x_feature_c = tf.feature_column.crossed_column(\n      columns=[\"feature_a\", feature_c_bucketized], ...)\n\n  feature_columns = [feature_b, feature_c_bucketized, feature_a_x_feature_c]\n  parsing_spec = tf.estimator.classifier_parse_example_spec(\n      feature_columns, label_key='my-label', label_dtype=tf.string)\n\n  # For the above example, classifier_parse_example_spec would return the dict:\n  assert parsing_spec == {\n    \"feature_a\": parsing_ops.VarLenFeature(tf.string),\n    \"feature_b\": parsing_ops.FixedLenFeature([1], dtype=tf.float32),\n    \"feature_c\": parsing_ops.FixedLenFeature([1], dtype=tf.float32)\n    \"my-label\" : parsing_ops.FixedLenFeature([1], dtype=tf.string)\n  }\n  ```\n\n  Example usage with a classifier:\n\n  ```python\n  feature_columns = # define features via tf.feature_column\n  estimator = DNNClassifier(\n      n_classes=1000,\n      feature_columns=feature_columns,\n      weight_column='example-weight',\n      label_vocabulary=['photos', 'keep', ...],\n      hidden_units=[256, 64, 16])\n  # This label configuration tells the classifier the following:\n  # * weights are retrieved with key 'example-weight'\n  # * label is string and can be one of the following ['photos', 'keep', ...]\n  # * integer id for label 'photos' is 0, 'keep' is 1, ...\n\n\n  # Input builders\n  def input_fn_train():  # Returns a tuple of features and labels.\n    features = tf.contrib.learn.read_keyed_batch_features(\n        
file_pattern=train_files,\n        batch_size=batch_size,\n        # creates parsing configuration for tf.parse_example\n        features=tf.estimator.classifier_parse_example_spec(\n            feature_columns,\n            label_key='my-label',\n            label_dtype=tf.string,\n            weight_column='example-weight'),\n        reader=tf.RecordIOReader)\n     labels = features.pop('my-label')\n     return features, labels\n\n  estimator.train(input_fn=input_fn_train)\n  ```\n\n  Args:\n    feature_columns: An iterable containing all feature columns. All items\n      should be instances of classes derived from `FeatureColumn`.\n    label_key: A string identifying the label. It means tf.Example stores labels\n      with this key.\n    label_dtype: A `tf.dtype` identifies the type of labels. By default it is\n      `tf.int64`. If user defines a `label_vocabulary`, this should be set as\n      `tf.string`. `tf.float32` labels are only supported for binary\n      classification.\n    label_default: used as label if label_key does not exist in given\n      tf.Example. An example usage: let's say `label_key` is 'clicked' and\n        tf.Example contains clicked data only for positive examples in following\n      format `key:clicked, value:1`. This means that if there is no data with\n        key 'clicked' it should count as negative example by setting\n        `label_default=0`. Type of this value should be compatible with\n        `label_dtype`.\n    weight_column: A string or a `NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It is used to down weight or boost examples during training. It\n      will be multiplied by the loss of the example. If it is a string, it is\n      used as a key to fetch weight tensor from the `features`. 
If it is a\n      `NumericColumn`, raw tensor is fetched by key `weight_column.key`, then\n      weight_column.normalizer_fn is applied on it to get weight tensor.\n\n  Returns:\n    A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature`\n    value.\n\n  Raises:\n    ValueError: If label is used in `feature_columns`.\n    ValueError: If weight_column is used in `feature_columns`.\n    ValueError: If any of the given `feature_columns` is not a `_FeatureColumn`\n      instance.\n    ValueError: If `weight_column` is not a `NumericColumn` instance.\n    ValueError: if label_key is None.\n  \"\"\"\n  parsing_spec = tf.compat.v2.feature_column.make_parse_example_spec(feature_columns)\n  label_spec = tf.io.FixedLenFeature((1,), label_dtype, label_default)\n  return _add_label_and_weight_to_parsing_spec(\n      parsing_spec=parsing_spec,\n      label_key=label_key,\n      label_spec=label_spec,\n      weight_column=weight_column)\n\n\n@estimator_export('estimator.regressor_parse_example_spec', v1=[])\ndef regressor_parse_example_spec_v2(feature_columns,\n                                    label_key,\n                                    label_dtype=tf.dtypes.float32,\n                                    label_default=None,\n                                    label_dimension=1,\n                                    weight_column=None):\n  \"\"\"Generates parsing spec for tf.parse_example to be used with regressors.\n\n  If users keep data in tf.Example format, they need to call tf.parse_example\n  with a proper feature spec. There are two main things that this utility helps:\n\n  * Users need to combine parsing spec of features with labels and weights\n    (if any) since they are all parsed from same tf.Example instance. This\n    utility combines these specs.\n  * It is difficult to map expected label by a regressor such as `DNNRegressor`\n    to corresponding tf.parse_example spec. 
This utility encodes it by getting\n    related information from users (key, dtype).\n\n  Example output of parsing spec:\n\n  ```python\n  # Define features and transformations\n  feature_b = tf.feature_column.numeric_column(...)\n  feature_c_bucketized = tf.feature_column.bucketized_column(\n    tf.feature_column.numeric_column(\"feature_c\"), ...)\n  feature_a_x_feature_c = tf.feature_column.crossed_column(\n      columns=[\"feature_a\", feature_c_bucketized], ...)\n\n  feature_columns = [feature_b, feature_c_bucketized, feature_a_x_feature_c]\n  parsing_spec = tf.estimator.regressor_parse_example_spec(\n      feature_columns, label_key='my-label')\n\n  # For the above example, regressor_parse_example_spec would return the dict:\n  assert parsing_spec == {\n    \"feature_a\": parsing_ops.VarLenFeature(tf.string),\n    \"feature_b\": parsing_ops.FixedLenFeature([1], dtype=tf.float32),\n    \"feature_c\": parsing_ops.FixedLenFeature([1], dtype=tf.float32),\n    \"my-label\": parsing_ops.FixedLenFeature([1], dtype=tf.float32)\n  }\n  ```\n\n  Example usage with a regressor:\n\n  ```python\n  feature_columns = # define features via tf.feature_column\n  estimator = DNNRegressor(\n      hidden_units=[256, 64, 16],\n      feature_columns=feature_columns,\n      weight_column='example-weight',\n      label_dimension=3)\n  # This label configuration tells the regressor the following:\n  # * weights are retrieved with key 'example-weight'\n  # * label is a 3 dimension tensor with float32 dtype.\n\n\n  # Input builders\n  def input_fn_train():  # Returns a tuple of features and labels.\n    features = tf.contrib.learn.read_keyed_batch_features(\n        file_pattern=train_files,\n        batch_size=batch_size,\n        # creates parsing configuration for tf.parse_example\n        features=tf.estimator.regressor_parse_example_spec(\n            feature_columns,\n            label_key='my-label',\n            label_dimension=3,\n            
weight_column='example-weight'),\n        reader=tf.RecordIOReader)\n     labels = features.pop('my-label')\n     return features, labels\n\n  estimator.train(input_fn=input_fn_train)\n  ```\n\n  Args:\n    feature_columns: An iterable containing all feature columns. All items\n      should be instances of classes derived from `_FeatureColumn`.\n    label_key: A string identifying the label. It means tf.Example stores labels\n      with this key.\n    label_dtype: A `tf.dtype` identifies the type of labels. By default it is\n      `tf.float32`.\n    label_default: used as label if label_key does not exist in given\n      tf.Example. By default default_value is none, which means\n      `tf.parse_example` will error out if there is any missing label.\n    label_dimension: Number of regression targets per example. This is the size\n      of the last dimension of the labels and logits `Tensor` objects\n      (typically, these have shape `[batch_size, label_dimension]`).\n    weight_column: A string or a `NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It is used to down weight or boost examples during training. It\n      will be multiplied by the loss of the example. If it is a string, it is\n      used as a key to fetch weight tensor from the `features`. 
If it is a\n      `NumericColumn`, raw tensor is fetched by key `weight_column.key`, then\n      weight_column.normalizer_fn is applied on it to get weight tensor.\n\n  Returns:\n    A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature`\n    value.\n\n  Raises:\n    ValueError: If label is used in `feature_columns`.\n    ValueError: If weight_column is used in `feature_columns`.\n    ValueError: If any of the given `feature_columns` is not a `_FeatureColumn`\n      instance.\n    ValueError: If `weight_column` is not a `NumericColumn` instance.\n    ValueError: if label_key is None.\n  \"\"\"\n  parsing_spec = tf.compat.v2.feature_column.make_parse_example_spec(feature_columns)\n  label_spec = tf.io.FixedLenFeature((label_dimension,), label_dtype,\n                                     label_default)\n  return _add_label_and_weight_to_parsing_spec(\n      parsing_spec=parsing_spec,\n      label_key=label_key,\n      label_spec=label_spec,\n      weight_column=weight_column)\n\n\ndef _add_label_and_weight_to_parsing_spec(parsing_spec,\n                                          label_key,\n                                          label_spec,\n                                          weight_column=None):\n  \"\"\"Adds label and weight spec to given parsing spec.\n\n  Args:\n    parsing_spec: A dict mapping each feature key to a `FixedLenFeature` or\n      `VarLenFeature` to which label and weight spec are added.\n    label_key: A string identifying the label. It means tf.Example stores labels\n      with this key.\n    label_spec: A `FixedLenFeature`.\n    weight_column: A string or a `NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It is used to down weight or boost examples during training. It\n      will be multiplied by the loss of the example. If it is a string, it is\n      used as a key to fetch weight tensor from the `features`. 
If it is a\n      `NumericColumn`, raw tensor is fetched by key `weight_column.key`, then\n      weight_column.normalizer_fn is applied on it to get weight tensor.\n\n  Returns:\n    A dict mapping each feature key to a `FixedLenFeature` or `VarLenFeature`\n      value.\n  \"\"\"\n  if label_key in parsing_spec:\n    raise ValueError('label should not be used as feature. '\n                     'label_key: {}, features: {}'.format(\n                         label_key, parsing_spec.keys()))\n  parsing_spec[label_key] = label_spec\n\n  if weight_column is None:\n    return parsing_spec\n\n  if isinstance(weight_column, six.string_types):\n    weight_column = tf.feature_column.numeric_column(weight_column)\n\n  if not isinstance(weight_column, fc.NumericColumn):\n    raise ValueError('weight_column should be an instance of '\n                     'tf.feature_column.numeric_column. '\n                     'Given type: {} value: {}'.format(\n                         type(weight_column), weight_column))\n\n  if weight_column.key in parsing_spec:\n    raise ValueError('weight_column should not be used as feature. 
'\n                     'weight_column: {}, features: {}'.format(\n                         weight_column.key, parsing_spec.keys()))\n\n  parsing_spec.update(weight_column.parse_example_spec)\n  return parsing_spec\n\n\n@estimator_export(v1=['estimator.classifier_parse_example_spec'])\ndef classifier_parse_example_spec(feature_columns,\n                                  label_key,\n                                  label_dtype=tf.dtypes.int64,\n                                  label_default=None,\n                                  weight_column=None):\n  parsing_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n      feature_columns)\n  label_spec = tf.io.FixedLenFeature((1,), label_dtype, label_default)\n  return _add_label_and_weight_to_parsing_spec(\n      parsing_spec=parsing_spec,\n      label_key=label_key,\n      label_spec=label_spec,\n      weight_column=weight_column)\n\n\nclassifier_parse_example_spec.__doc__ = classifier_parse_example_spec_v2.__doc__\n\n\n@estimator_export(v1=['estimator.regressor_parse_example_spec'])\ndef regressor_parse_example_spec(\n    feature_columns,  # pylint: disable=missing-docstring\n    label_key,\n    label_dtype=tf.dtypes.float32,\n    label_default=None,\n    label_dimension=1,\n    weight_column=None):\n  parsing_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n      feature_columns)\n  label_spec = tf.io.FixedLenFeature((label_dimension,), label_dtype,\n                                     label_default)\n  return _add_label_and_weight_to_parsing_spec(\n      parsing_spec=parsing_spec,\n      label_key=label_key,\n      label_spec=label_spec,\n      weight_column=weight_column)\n\n\nregressor_parse_example_spec.__doc__ = regressor_parse_example_spec_v2.__doc__\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/parsing_utils_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for parsing_utils.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.canned import parsing_utils\n\n\nclass BaseClassifierParseExampleSpec(object):\n  \"\"\"Tests tf.estimator.classifier_parse_example_spec.\"\"\"\n\n  def __init__(self, parse_example_fn):\n    self._parse_example_fn = parse_example_fn\n\n  def test_defaults(self):\n    parsing_spec = self._parse_example_fn(\n        feature_columns=[tf.feature_column.numeric_column('a')], label_key='b')\n    expected_spec = {\n        'a': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n        'b': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.int64),\n    }\n    self.assertDictEqual(expected_spec, parsing_spec)\n\n  def test_string(self):\n    parsing_spec = self._parse_example_fn(\n        feature_columns=[tf.feature_column.numeric_column('a')],\n        label_key='b',\n        label_dtype=tf.dtypes.string)\n    expected_spec = {\n        'a': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n        'b': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.string),\n    }\n    self.assertDictEqual(expected_spec, parsing_spec)\n\n  # TODO(ispir): 
test label_default_value compatibility with label_dtype\n  def test_label_default_value(self):\n    parsing_spec = self._parse_example_fn(\n        feature_columns=[tf.feature_column.numeric_column('a')],\n        label_key='b',\n        label_default=0)\n    expected_spec = {\n        'a':\n            tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n        'b':\n            tf.io.FixedLenFeature((1,), dtype=tf.dtypes.int64, default_value=0),\n    }\n    self.assertDictEqual(expected_spec, parsing_spec)\n\n  def test_weight_column_as_string(self):\n    parsing_spec = self._parse_example_fn(\n        feature_columns=[tf.feature_column.numeric_column('a')],\n        label_key='b',\n        weight_column='c')\n    expected_spec = {\n        'a': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n        'b': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.int64),\n        'c': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n    }\n    self.assertDictEqual(expected_spec, parsing_spec)\n\n  def test_weight_column_as_numeric_column(self):\n    parsing_spec = self._parse_example_fn(\n        feature_columns=[tf.feature_column.numeric_column('a')],\n        label_key='b',\n        weight_column=tf.feature_column.numeric_column('c'))\n    expected_spec = {\n        'a': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n        'b': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.int64),\n        'c': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n    }\n    self.assertDictEqual(expected_spec, parsing_spec)\n\n  def test_label_key_should_not_be_used_as_feature(self):\n    with self.assertRaisesRegexp(ValueError,\n                                 'label should not be used as feature'):\n      self._parse_example_fn(\n          feature_columns=[tf.feature_column.numeric_column('a')],\n          label_key='a')\n\n  def test_weight_column_should_not_be_used_as_feature(self):\n    with self.assertRaisesRegexp(ValueError,\n                                 
'weight_column should not be used as feature'):\n      self._parse_example_fn(\n          feature_columns=[tf.feature_column.numeric_column('a')],\n          label_key='b',\n          weight_column=tf.feature_column.numeric_column('a'))\n\n  def test_weight_column_should_be_a_numeric_column(self):\n    with self.assertRaisesRegexp(ValueError,\n                                 'tf.feature_column.numeric_column'):\n      not_a_numeric_column = 3\n      self._parse_example_fn(\n          feature_columns=[tf.feature_column.numeric_column('a')],\n          label_key='b',\n          weight_column=not_a_numeric_column)\n\n\nclass ClassifierParseExampleSpecV2(BaseClassifierParseExampleSpec,\n                                   tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    BaseClassifierParseExampleSpec.__init__(\n        self, parsing_utils.classifier_parse_example_spec_v2)\n\n  def test_non_v1_feature_column(self):\n    parsing_spec = self._parse_example_fn(\n        feature_columns=[tf.feature_column.sequence_numeric_column('a')],\n        label_key='b')\n    expected_spec = {\n        'a': tf.io.VarLenFeature(dtype=tf.dtypes.float32),\n        'b': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.int64),\n    }\n    self.assertDictEqual(expected_spec, parsing_spec)\n\n\nclass ClassifierParseExampleSpecV1(BaseClassifierParseExampleSpec,\n                                   tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    BaseClassifierParseExampleSpec.__init__(\n        self, parsing_utils.classifier_parse_example_spec)\n\n\nclass BaseRegressorParseExampleSpec(object):\n  \"\"\"Tests tf.estimator.regressor_parse_example_spec.\"\"\"\n\n  def __init__(self, parse_example_fn):\n    self._parse_example_fn = parse_example_fn\n\n  def test_defaults(self):\n    parsing_spec = 
self._parse_example_fn(\n        feature_columns=[tf.feature_column.numeric_column('a')], label_key='b')\n    expected_spec = {\n        'a': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n        'b': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n    }\n    self.assertDictEqual(expected_spec, parsing_spec)\n\n  def test_int64(self):\n    parsing_spec = self._parse_example_fn(\n        feature_columns=[tf.feature_column.numeric_column('a')],\n        label_key='b',\n        label_dtype=tf.dtypes.int64)\n    expected_spec = {\n        'a': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n        'b': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.int64),\n    }\n    self.assertDictEqual(expected_spec, parsing_spec)\n\n  def test_label_default_value(self):\n    parsing_spec = self._parse_example_fn(\n        feature_columns=[tf.feature_column.numeric_column('a')],\n        label_key='b',\n        label_default=0.)\n    expected_spec = {\n        'a':\n            tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n        'b':\n            tf.io.FixedLenFeature((1,),\n                                  dtype=tf.dtypes.float32,\n                                  default_value=0.),\n    }\n    self.assertDictEqual(expected_spec, parsing_spec)\n\n  def test_label_dimension(self):\n    parsing_spec = self._parse_example_fn(\n        feature_columns=[tf.feature_column.numeric_column('a')],\n        label_key='b',\n        label_dimension=3)\n    expected_spec = {\n        'a': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n        'b': tf.io.FixedLenFeature((3,), dtype=tf.dtypes.float32),\n    }\n    self.assertDictEqual(expected_spec, parsing_spec)\n\n  def test_weight_column_as_string(self):\n    parsing_spec = self._parse_example_fn(\n        feature_columns=[tf.feature_column.numeric_column('a')],\n        label_key='b',\n        weight_column='c')\n    expected_spec = {\n        'a': tf.io.FixedLenFeature((1,), 
dtype=tf.dtypes.float32),\n        'b': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n        'c': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n    }\n    self.assertDictEqual(expected_spec, parsing_spec)\n\n  def test_weight_column_as_numeric_column(self):\n    parsing_spec = self._parse_example_fn(\n        feature_columns=[tf.feature_column.numeric_column('a')],\n        label_key='b',\n        weight_column=tf.feature_column.numeric_column('c'))\n    expected_spec = {\n        'a': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n        'b': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n        'c': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n    }\n    self.assertDictEqual(expected_spec, parsing_spec)\n\n  def test_label_key_should_not_be_used_as_feature(self):\n    with self.assertRaisesRegexp(ValueError,\n                                 'label should not be used as feature'):\n      self._parse_example_fn(\n          feature_columns=[tf.feature_column.numeric_column('a')],\n          label_key='a')\n\n  def test_weight_column_should_not_be_used_as_feature(self):\n    with self.assertRaisesRegexp(ValueError,\n                                 'weight_column should not be used as feature'):\n      self._parse_example_fn(\n          feature_columns=[tf.feature_column.numeric_column('a')],\n          label_key='b',\n          weight_column=tf.feature_column.numeric_column('a'))\n\n  def test_weight_column_should_be_a_numeric_column(self):\n    with self.assertRaisesRegexp(ValueError,\n                                 'tf.feature_column.numeric_column'):\n      not_a_numeric_column = 3\n      self._parse_example_fn(\n          feature_columns=[tf.feature_column.numeric_column('a')],\n          label_key='b',\n          weight_column=not_a_numeric_column)\n\n\nclass RegressorParseExampleSpecV2(BaseRegressorParseExampleSpec,\n                                  tf.test.TestCase):\n\n  def __init__(self, 
methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    BaseRegressorParseExampleSpec.__init__(\n        self, parsing_utils.regressor_parse_example_spec_v2)\n\n  def test_non_v1_feature_column(self):\n    parsing_spec = self._parse_example_fn(\n        feature_columns=[tf.feature_column.sequence_numeric_column('a')],\n        label_key='b')\n    expected_spec = {\n        'a': tf.io.VarLenFeature(dtype=tf.dtypes.float32),\n        'b': tf.io.FixedLenFeature((1,), dtype=tf.dtypes.float32),\n    }\n    self.assertDictEqual(expected_spec, parsing_spec)\n\n\nclass RegressorParseExampleSpecV1(BaseRegressorParseExampleSpec,\n                                  tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    BaseRegressorParseExampleSpec.__init__(\n        self, parsing_utils.regressor_parse_example_spec)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/prediction_keys.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Enum for model prediction keys.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n\nclass PredictionKeys(object):\n  \"\"\"Enum for canonical model prediction keys.\n\n  The following values are defined:\n  PREDICTIONS: Used by models that predict values, such as regressor models.\n  \"\"\"\n\n  CLASSES = 'classes'\n  CLASS_IDS = 'class_ids'\n  ALL_CLASSES = 'all_classes'\n  ALL_CLASS_IDS = 'all_class_ids'\n  LOGISTIC = 'logistic'\n  LOGITS = 'logits'\n  PREDICTIONS = 'predictions'\n  PROBABILITIES = 'probabilities'\n  TOP_K = 'top_k'\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/rnn.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Recurrent Neural Network model and estimators.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.feature_column import feature_column_lib as fc\nfrom tensorflow.python.framework import ops\nfrom tensorflow_estimator.python.estimator import estimator\nfrom tensorflow_estimator.python.estimator import model_fn\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v1\nfrom tensorflow_estimator.python.estimator.canned import optimizers\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.head import binary_class_head as binary_head_lib\nfrom tensorflow_estimator.python.estimator.head import multi_class_head as multi_head_lib\nfrom tensorflow_estimator.python.estimator.head import sequential_head as seq_head_lib\n\n# The defaults are historical artifacts of the initial implementation, but seem\n# reasonable choices.\n# TODO(aarg): Also apply default learning rate and clipping to Keras model so\n# they apply when the optimizer is set via `compile` and the model trained via\n# 
the `fit` method.\n_DEFAULT_LEARNING_RATE = 0.05\n_DEFAULT_CLIP_NORM = 5.0\n\n_SIMPLE_RNN_KEY = 'simple_rnn'\n_LSTM_KEY = 'lstm'\n_GRU_KEY = 'gru'\n\n_CELL_TYPE_TO_LAYER_MAPPING = {\n    _LSTM_KEY: tf_keras.layers.LSTM,\n    _GRU_KEY: tf_keras.layers.GRU,\n    _SIMPLE_RNN_KEY: tf_keras.layers.SimpleRNN\n}\n\n_CELL_TYPES = {\n    _LSTM_KEY: tf_keras.layers.LSTMCell,\n    _GRU_KEY: tf_keras.layers.GRUCell,\n    _SIMPLE_RNN_KEY: tf_keras.layers.SimpleRNNCell\n}\n\n# Indicates no value was provided by the user to a kwarg.\nUSE_DEFAULT = object()\n\n\ndef _single_rnn_cell(units, cell_type):\n  \"\"\"Initializes a RNN cell.\"\"\"\n  cell_type = _CELL_TYPES.get(cell_type, cell_type)\n  if not callable(cell_type):\n    raise ValueError(\n        '`cell_type` should be a class producing a RNN cell, or a string '\n        'specifying the cell type. Supported strings are: {}.'.format(\n            [_SIMPLE_RNN_KEY, _LSTM_KEY, _GRU_KEY]))\n  cell = cell_type(units=units)\n  if hasattr(cell, '_enable_caching_device'):\n    # Enable the caching_device to speed up the repetitive variable read in\n    # tf.while. This should work only with tf.session.\n    cell._enable_caching_device = True  # pylint: disable=protected-access\n  if not hasattr(cell, 'call') or not hasattr(cell, 'state_size'):\n    raise ValueError('RNN cell should have a `call` and `state_size` method.')\n  return cell\n\n\ndef _make_rnn_cell_fn(units, cell_type=_SIMPLE_RNN_KEY):\n  \"\"\"Convenience function to create `rnn_cell_fn` for canned RNN Estimators.\n\n  Args:\n    units: Iterable of integer number of hidden units per RNN layer.\n    cell_type: A class producing a RNN cell or a string specifying the cell\n      type. 
Supported strings are: `'simple_rnn'`, `'lstm'`, and `'gru'`.\n\n  Returns:\n    A function that returns a RNN cell.\n\n  Raises:\n    ValueError: If cell_type is not supported.\n  \"\"\"\n\n  def rnn_cell_fn():\n    cells = [_single_rnn_cell(n, cell_type) for n in units]\n    if len(cells) == 1:\n      return cells[0]\n    return cells\n\n  return rnn_cell_fn\n\n\nclass RNNModel(tf_keras.models.Model):\n  \"\"\"A Keras RNN model.\n\n  Composition of layers to compute logits from RNN model, along with training\n  and inference features. See `tf_keras.models.Model` for more details on Keras\n  models.\n\n  Example of usage:\n\n  ```python\n  rating = tf.feature_column.embedding_column(\n      tf.feature_column.sequence_categorical_column_with_identity('rating', 5),\n      10)\n  rnn_layer = tf_keras.layers.SimpleRNN(20)\n  rnn_model = RNNModel(rnn_layer, units=1, sequence_feature_columns=[rating])\n\n  rnn_model.compile(\n      tf_keras.optimizers.Adam(), loss=tf_keras.losses.MeanSquaredError())\n  rnn_model.fit(generator(), epochs=10, steps_per_epoch=100)\n  rnn_model.predict({'rating': np.array([[0, 1], [2, 3]])}, steps=1)\n  ```\n  \"\"\"\n\n  # TODO(aarg): Update arguments to support multiple rnn layers.\n  def __init__(self,\n               rnn_layer,\n               units,\n               sequence_feature_columns,\n               context_feature_columns=None,\n               activation=None,\n               return_sequences=False,\n               **kwargs):\n    \"\"\"Initializes a RNNModel instance.\n\n    Args:\n      rnn_layer: A Keras RNN layer.\n      units: An int indicating the dimension of the logit layer, and of the\n        model output.\n      sequence_feature_columns: An iterable containing the `FeatureColumn`s that\n        represent sequential input. All items in the set should either be\n        sequence columns (e.g. `sequence_numeric_column`) or constructed from\n        one (e.g. 
`embedding_column` with `sequence_categorical_column_*` as\n        input).\n      context_feature_columns: An iterable containing the `FeatureColumn`s for\n        contextual input. The data represented by these columns will be\n        replicated and given to the RNN at each timestep. These columns must be\n        instances of classes derived from `DenseColumn` such as\n        `numeric_column`, not the sequential variants.\n      activation: Activation function to apply to the logit layer (for instance\n        `tf_keras.activations.sigmoid`). If you don't specify anything, no\n        activation is applied.\n      return_sequences: A boolean indicating whether to return the last output\n        in the output sequence, or the full sequence.\n      **kwargs: Additional arguments.\n\n    Raises:\n      ValueError: If `units` is not an int.\n    \"\"\"\n    super(RNNModel, self).__init__(**kwargs)\n    if not isinstance(units, int):\n      raise ValueError('units must be an int.  Given type: {}'.format(\n          type(units)))\n    self._return_sequences = return_sequences\n    self._sequence_feature_columns = sequence_feature_columns\n    self._context_feature_columns = context_feature_columns\n    self._sequence_features_layer = tf_keras.experimental.SequenceFeatures(\n        sequence_feature_columns)\n    self._dense_features_layer = None\n    if context_feature_columns:\n      self._dense_features_layer = tf_keras_v1.layers.DenseFeatures(\n          context_feature_columns)\n    self._rnn_layer = rnn_layer\n    self._logits_layer = tf_keras.layers.Dense(\n        units=units, activation=activation, name='logits')\n\n  def call(self, inputs, training=None):\n    \"\"\"Computes the RNN output.\n\n    By default no activation is applied and the logits are returned. 
To output\n    probabilites an activation needs to be specified such as sigmoid or softmax.\n\n    Args:\n      inputs: A dict mapping keys to input tensors.\n      training: Python boolean indicating whether the layers should behave in\n        training mode or in inference mode. This argument is passed to the\n        model's layers. This is for instance used with cells that use dropout.\n\n    Returns:\n      A `Tensor` with logits from RNN model. It has shape\n      (batch_size, time_step, logits_size) if `return_sequence` is `True`,\n      (batch_size, logits_size) otherwise.\n    \"\"\"\n    if not isinstance(inputs, dict):\n      raise ValueError('inputs should be a dictionary of `Tensor`s. '\n                       'Given type: {}'.format(type(inputs)))\n    with ops.name_scope('sequence_input_layer'):\n      try:\n        sequence_input, sequence_length = self._sequence_features_layer(\n            inputs, training=training)\n      except TypeError:\n        sequence_input, sequence_length = self._sequence_features_layer(inputs)\n      tf.compat.v1.summary.histogram('sequence_length', sequence_length)\n\n      if self._context_feature_columns:\n        try:\n          context_input = self._dense_features_layer(inputs, training=training)\n        except TypeError:\n          context_input = self._dense_features_layer(inputs)\n        sequence_input = fc.concatenate_context_input(\n            context_input, sequence_input=sequence_input)\n\n    sequence_length_mask = tf.sequence_mask(sequence_length)\n    rnn_outputs = self._rnn_layer(\n        sequence_input, mask=sequence_length_mask, training=training)\n\n    logits = self._logits_layer(rnn_outputs)\n    if self._return_sequences:\n      # Passes sequence mask as `_keras_mask` to be used in Keras model for\n      # loss and metrics aggregation to exclude padding in the sequential case.\n      logits._keras_mask = sequence_length_mask  # pylint: disable=protected-access\n    return logits\n\n  def 
get_config(self):\n    \"\"\"Returns a dictionary with the config of the model.\"\"\"\n    config = {'name': self.name}\n    config['rnn_layer'] = {\n        'class_name': self._rnn_layer.__class__.__name__,\n        'config': self._rnn_layer.get_config()\n    }\n    config['units'] = self._logits_layer.units\n    config['return_sequences'] = self._return_sequences\n    config['activation'] = tf_keras.activations.serialize(self._logits_layer.activation)\n    config['sequence_feature_columns'] = fc.serialize_feature_columns(\n        self._sequence_feature_columns)\n    config['context_feature_columns'] = (\n        fc.serialize_feature_columns(self._context_feature_columns)\n        if self._context_feature_columns else None)\n    return config\n\n  @classmethod\n  def from_config(cls, config, custom_objects=None):\n    \"\"\"Creates a RNNModel from its config.\n\n    Args:\n      config: A Python dictionary, typically the output of `get_config`.\n      custom_objects: Optional dictionary mapping names (strings) to custom\n        classes or functions to be considered during deserialization.\n\n    Returns:\n      A RNNModel.\n    \"\"\"\n    rnn_layer = tf_keras.layers.deserialize(\n        config.pop('rnn_layer'), custom_objects=custom_objects)\n    sequence_feature_columns = fc.deserialize_feature_columns(\n        config.pop('sequence_feature_columns'), custom_objects=custom_objects)\n    context_feature_columns = config.pop('context_feature_columns', None)\n    if context_feature_columns:\n      context_feature_columns = fc.deserialize_feature_columns(\n          context_feature_columns, custom_objects=custom_objects)\n    activation = tf_keras.activations.deserialize(\n        config.pop('activation', None), custom_objects=custom_objects)\n    return cls(\n        rnn_layer=rnn_layer,\n        sequence_feature_columns=sequence_feature_columns,\n        context_feature_columns=context_feature_columns,\n        activation=activation,\n        **config)\n\n\ndef 
_get_rnn_estimator_spec(features, labels, mode, head, rnn_model, optimizer,\n                            return_sequences):\n  \"\"\"Computes `EstimatorSpec` from logits to use in estimator model function.\n\n  Args:\n    features: dict of `Tensor` and `SparseTensor` objects returned from\n      `input_fn`.\n    labels: `Tensor` of shape [batch_size, 1] or [batch_size] with labels.\n    mode: Defines whether this is training, evaluation or prediction. See\n      `ModeKeys`.\n    head: A `Head` instance.\n    rnn_model: A Keras model that computes RNN logits from features.\n    optimizer: String, `tf_keras.optimizers.Optimizer` object, or callable that\n      creates the optimizer to use for training. If not specified, will use the\n      Adagrad optimizer with a default learning rate of 0.05 and gradient clip\n      norm of 5.0.\n    return_sequences: A boolean indicating whether to return the last output in\n      the output sequence, or the full sequence.\n\n  Returns:\n    An `EstimatorSpec` instance.\n\n  Raises:\n    ValueError: If mode or optimizer is invalid, or features has the wrong type.\n  \"\"\"\n  training = (mode == model_fn.ModeKeys.TRAIN)\n  # In TRAIN mode, create optimizer and assign global_step variable to\n  # optimizer.iterations to make global_step increased correctly, as Hooks\n  # relies on global step as step counter - otherwise skip optimizer\n  # initialization and set it to None.\n  if training:\n    # If user does not provide an optimizer instance, use the optimizer\n    # specified by the string with default learning rate and gradient clipping.\n    if isinstance(optimizer, six.string_types):\n      optimizer = optimizers.get_optimizer_instance_v2(\n          optimizer, learning_rate=_DEFAULT_LEARNING_RATE)\n      optimizer.clipnorm = _DEFAULT_CLIP_NORM\n    else:\n      optimizer = optimizers.get_optimizer_instance_v2(optimizer)\n    optimizer.iterations = tf.compat.v1.train.get_or_create_global_step()\n  else:\n    optimizer = 
None\n\n  logits = rnn_model(features, training)\n\n  if return_sequences and head.input_sequence_mask_key not in features:\n    features[head.input_sequence_mask_key] = logits._keras_mask  # pylint: disable=protected-access\n\n  return head.create_estimator_spec(\n      features=features,\n      mode=mode,\n      labels=labels,\n      optimizer=optimizer,\n      logits=logits,\n      update_ops=rnn_model.updates,\n      trainable_variables=rnn_model.trainable_variables)\n\n\ndef _verify_rnn_cell_input(rnn_cell_fn, units, cell_type):\n  if rnn_cell_fn and (units or cell_type != USE_DEFAULT):\n    raise ValueError(\n        'units and cell_type must not be specified when using rnn_cell_fn')\n\n\ndef _make_rnn_layer(rnn_cell_fn, units, cell_type, return_sequences):\n  \"\"\"Assert arguments are valid and return rnn_layer_fn.\n\n  Args:\n    rnn_cell_fn: A function that returns a RNN cell instance that will be used\n      to construct the RNN.\n    units: Iterable of integer number of hidden units per RNN layer.\n    cell_type: A class producing a RNN cell or a string specifying the cell\n      type.\n    return_sequences: A boolean indicating whether to return the last output\n      in the output sequence, or the full sequence.:\n\n  Returns:\n    A tf_keras.layers.RNN layer.\n  \"\"\"\n  _verify_rnn_cell_input(rnn_cell_fn, units, cell_type)\n  if cell_type in _CELL_TYPE_TO_LAYER_MAPPING and isinstance(units, int):\n    return _CELL_TYPE_TO_LAYER_MAPPING[cell_type](\n        units=units, return_sequences=return_sequences)\n  if not rnn_cell_fn:\n    if cell_type == USE_DEFAULT:\n      cell_type = _SIMPLE_RNN_KEY\n    rnn_cell_fn = _make_rnn_cell_fn(units, cell_type)\n\n  return tf_keras.layers.RNN(cell=rnn_cell_fn(), return_sequences=return_sequences)\n\n\n@estimator_export('estimator.experimental.RNNEstimator', v1=[])\nclass RNNEstimator(estimator.Estimator):\n  \"\"\"An Estimator for TensorFlow RNN models with user-specified head.\n\n  Example:\n\n  ```python\n  
token_sequence = sequence_categorical_column_with_hash_bucket(...)\n  token_emb = embedding_column(categorical_column=token_sequence, ...)\n\n  estimator = RNNEstimator(\n      head=tf.estimator.RegressionHead(),\n      sequence_feature_columns=[token_emb],\n      units=[32, 16], cell_type='lstm')\n\n  # Or with custom RNN cell:\n  def rnn_cell_fn(_):\n    cells = [ tf_keras.layers.LSTMCell(size) for size in [32, 16] ]\n    return tf_keras.layers.StackedRNNCells(cells)\n\n  estimator = RNNEstimator(\n      head=tf.estimator.RegressionHead(),\n      sequence_feature_columns=[token_emb],\n      rnn_cell_fn=rnn_cell_fn)\n\n  # Input builders\n  def input_fn_train: # returns x, y\n    pass\n  estimator.train(input_fn=input_fn_train, steps=100)\n\n  def input_fn_eval: # returns x, y\n    pass\n  metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)\n  def input_fn_predict: # returns x, None\n    pass\n  predictions = estimator.predict(input_fn=input_fn_predict)\n  ```\n\n  Input of `train` and `evaluate` should have following features,\n  otherwise there will be a `KeyError`:\n\n  * if the head's `weight_column` is not `None`, a feature with\n    `key=weight_column` whose value is a `Tensor`.\n  * for each `column` in `sequence_feature_columns`:\n    - a feature with `key=column.name` whose `value` is a `SparseTensor`.\n  * for each `column` in `context_feature_columns`:\n    - if `column` is a `CategoricalColumn`, a feature with `key=column.name`\n      whose `value` is a `SparseTensor`.\n    - if `column` is a `WeightedCategoricalColumn`, two features: the first\n      with `key` the id column name, the second with `key` the weight column\n      name. 
Both features' `value` must be a `SparseTensor`.\n    - if `column` is a `DenseColumn`, a feature with `key=column.name`\n      whose `value` is a `Tensor`.\n\n  Loss and predicted output are determined by the specified head.\n\n  @compatibility(eager)\n  Estimators are not compatible with eager execution.\n  @end_compatibility\n  \"\"\"\n\n  def __init__(self,\n               head,\n               sequence_feature_columns,\n               context_feature_columns=None,\n               units=None,\n               cell_type=USE_DEFAULT,\n               rnn_cell_fn=None,\n               return_sequences=False,\n               model_dir=None,\n               optimizer='Adagrad',\n               config=None):\n    \"\"\"Initializes a `RNNEstimator` instance.\n\n    Args:\n      head: A `Head` instance. This specifies the model's output and loss\n        function to be optimized.\n      sequence_feature_columns: An iterable containing the `FeatureColumn`s that\n        represent sequential input. All items in the set should either be\n        sequence columns (e.g. `sequence_numeric_column`) or constructed from\n        one (e.g. `embedding_column` with `sequence_categorical_column_*` as\n        input).\n      context_feature_columns: An iterable containing the `FeatureColumn`s for\n        contextual input. The data represented by these columns will be\n        replicated and given to the RNN at each timestep. These columns must be\n        instances of classes derived from `DenseColumn` such as\n        `numeric_column`, not the sequential variants.\n      units: Iterable of integer number of hidden units per RNN layer. If set,\n        `cell_type` must also be specified and `rnn_cell_fn` must be `None`.\n      cell_type: A class producing a RNN cell or a string specifying the cell\n        type. Supported strings are: `'simple_rnn'`, `'lstm'`, and `'gru'`. 
If\n          set, `units` must also be specified and `rnn_cell_fn` must be `None`.\n      rnn_cell_fn: A function that returns a RNN cell instance that will be used\n        to construct the RNN. If set, `units` and `cell_type` cannot be set.\n        This is for advanced users who need additional customization beyond\n        `units` and `cell_type`. Note that `tf_keras.layers.StackedRNNCells` is\n        needed for stacked RNNs.\n      return_sequences: A boolean indicating whether to return the last output\n        in the output sequence, or the full sequence.\n      model_dir: Directory to save model parameters, graph and etc. This can\n        also be used to load checkpoints from the directory into a estimator to\n        continue training a previously saved model.\n      optimizer: An instance of `tf.Optimizer` or string specifying optimizer\n        type. Defaults to Adagrad optimizer.\n      config: `RunConfig` object to configure the runtime settings.\n\n    Note that a RNN cell has:\n      - a `call` method.\n      - a `state_size` attribute.\n      - a `output_size` attribute.\n      - a `get_initial_state` method.\n\n    See the documentation on `tf_keras.layers.RNN` for more details.\n\n    Raises:\n      ValueError: If `units`, `cell_type`, and `rnn_cell_fn` are not\n        compatible.\n    \"\"\"\n\n    # TODO(aarg): Instead of raising an error convert head to sequential head.\n    if return_sequences and not isinstance(head, seq_head_lib._SequentialHead):  # pylint: disable=protected-access\n      raise ValueError('Provided head must be a `_SequentialHead` object when '\n                       '`return_sequences` is set to True.')\n    _verify_rnn_cell_input(rnn_cell_fn, units, cell_type)\n\n    def _model_fn(features, labels, mode, config):\n      \"\"\"RNNEstimator model function.\"\"\"\n      del config  # Unused.\n      rnn_layer = _make_rnn_layer(\n          rnn_cell_fn=rnn_cell_fn,\n          units=units,\n          cell_type=cell_type,\n   
       return_sequences=return_sequences)\n      rnn_model = RNNModel(\n          rnn_layer=rnn_layer,\n          units=head.logits_dimension,\n          sequence_feature_columns=sequence_feature_columns,\n          context_feature_columns=context_feature_columns,\n          return_sequences=return_sequences,\n          name='rnn_model')\n      return _get_rnn_estimator_spec(\n          features,\n          labels,\n          mode,\n          head=head,\n          rnn_model=rnn_model,\n          optimizer=optimizer,\n          return_sequences=return_sequences)\n\n    super(RNNEstimator, self).__init__(\n        model_fn=_model_fn, model_dir=model_dir, config=config)\n\n\n@estimator_export('estimator.experimental.RNNClassifier', v1=[])\nclass RNNClassifier(RNNEstimator):\n  \"\"\"A classifier for TensorFlow RNN models.\n\n  Trains a recurrent neural network model to classify instances into one of\n  multiple classes.\n\n  Example:\n\n  ```python\n  token_sequence = sequence_categorical_column_with_hash_bucket(...)\n  token_emb = embedding_column(categorical_column=token_sequence, ...)\n\n  estimator = RNNClassifier(\n      sequence_feature_columns=[token_emb],\n      units=[32, 16], cell_type='lstm')\n\n  # Input builders\n  def input_fn_train: # returns x, y\n    pass\n  estimator.train(input_fn=input_fn_train, steps=100)\n\n  def input_fn_eval: # returns x, y\n    pass\n  metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)\n  def input_fn_predict: # returns x, None\n    pass\n  predictions = estimator.predict(input_fn=input_fn_predict)\n  ```\n\n  Input of `train` and `evaluate` should have following features,\n  otherwise there will be a `KeyError`:\n\n  * if `weight_column` is not `None`, a feature with\n    `key=weight_column` whose value is a `Tensor`.\n  * for each `column` in `sequence_feature_columns`:\n    - a feature with `key=column.name` whose `value` is a `SparseTensor`.\n  * for each `column` in `context_feature_columns`:\n    - if 
`column` is a `CategoricalColumn`, a feature with `key=column.name`\n      whose `value` is a `SparseTensor`.\n    - if `column` is a `WeightedCategoricalColumn`, two features: the first\n      with `key` the id column name, the second with `key` the weight column\n      name. Both features' `value` must be a `SparseTensor`.\n    - if `column` is a `DenseColumn`, a feature with `key=column.name`\n      whose `value` is a `Tensor`.\n\n  Loss is calculated by using softmax cross entropy.\n\n  @compatibility(eager)\n  Estimators are not compatible with eager execution.\n  @end_compatibility\n  \"\"\"\n\n  def __init__(self,\n               sequence_feature_columns,\n               context_feature_columns=None,\n               units=None,\n               cell_type=USE_DEFAULT,\n               rnn_cell_fn=None,\n               return_sequences=False,\n               model_dir=None,\n               n_classes=2,\n               weight_column=None,\n               label_vocabulary=None,\n               optimizer='Adagrad',\n               loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,\n               sequence_mask='sequence_mask',\n               config=None):\n    \"\"\"Initializes a `RNNClassifier` instance.\n\n    Args:\n      sequence_feature_columns: An iterable containing the `FeatureColumn`s that\n        represent sequential input. All items in the set should either be\n        sequence columns (e.g. `sequence_numeric_column`) or constructed from\n        one (e.g. `embedding_column` with `sequence_categorical_column_*` as\n        input).\n      context_feature_columns: An iterable containing the `FeatureColumn`s for\n        contextual input. The data represented by these columns will be\n        replicated and given to the RNN at each timestep. 
These columns must be\n        instances of classes derived from `DenseColumn` such as\n        `numeric_column`, not the sequential variants.\n      units: Iterable of integer number of hidden units per RNN layer. If set,\n        `cell_type` must also be specified and `rnn_cell_fn` must be `None`.\n      cell_type: A class producing a RNN cell or a string specifying the cell\n        type. Supported strings are: `'simple_rnn'`, `'lstm'`, and `'gru'`. If\n          set, `units` must also be specified and `rnn_cell_fn` must be `None`.\n      rnn_cell_fn: A function that returns a RNN cell instance that will be used\n        to construct the RNN. If set, `units` and `cell_type` cannot be set.\n        This is for advanced users who need additional customization beyond\n        `units` and `cell_type`. Note that `tf_keras.layers.StackedRNNCells` is\n        needed for stacked RNNs.\n      return_sequences: A boolean indicating whether to return the last output\n        in the output sequence, or the full sequence. Note that if True,\n        `weight_column` must be None or a string.\n      model_dir: Directory to save model parameters, graph and etc. This can\n        also be used to load checkpoints from the directory into a estimator to\n        continue training a previously saved model.\n      n_classes: Number of label classes. Defaults to 2, namely binary\n        classification. Must be > 1.\n      weight_column: A string or a `NumericColumn` created by\n        `tf.feature_column.numeric_column` defining feature column representing\n        weights. It is used to down weight or boost examples during training. It\n        will be multiplied by the loss of the example. If it is a string, it is\n        used as a key to fetch weight tensor from the `features`. 
If it is a\n        `NumericColumn`, raw tensor is fetched by key `weight_column.key`, then\n        weight_column.normalizer_fn is applied on it to get weight tensor.\n      label_vocabulary: A list of strings represents possible label values. If\n        given, labels must be string type and have any value in\n        `label_vocabulary`. If it is not given, that means labels are already\n        encoded as integer or float within [0, 1] for `n_classes=2` and encoded\n        as integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also\n        there will be errors if vocabulary is not provided and labels are\n        string.\n      optimizer: An instance of `tf.Optimizer` or string specifying optimizer\n        type. Defaults to Adagrad optimizer.\n      loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how\n        to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.\n      sequence_mask: A string with the name of the sequence mask tensor. If\n        `sequence_mask` is in the features dictionary, the provided tensor is\n        used, otherwise the sequence mask is computed from the length of\n        sequential features. The sequence mask is used in evaluation and\n        training mode to aggregate loss and metrics computation while excluding\n        padding steps. 
It is also added to the predictions dictionary in\n        prediction mode to indicate which steps are padding.\n      config: `RunConfig` object to configure the runtime settings.\n\n    Note that a RNN cell has:\n      - a `call` method.\n      - a `state_size` attribute.\n      - a `output_size` attribute.\n      - a `get_initial_state` method.\n\n    See the documentation on `tf_keras.layers.RNN` for more details.\n\n    Raises:\n      ValueError: If `units`, `cell_type`, and `rnn_cell_fn` are not\n        compatible.\n    \"\"\"\n    if n_classes == 2:\n      head = binary_head_lib.BinaryClassHead(\n          weight_column=weight_column,\n          label_vocabulary=label_vocabulary,\n          loss_reduction=loss_reduction)\n    else:\n      head = multi_head_lib.MultiClassHead(\n          n_classes=n_classes,\n          weight_column=weight_column,\n          label_vocabulary=label_vocabulary,\n          loss_reduction=loss_reduction)\n\n    if return_sequences:\n      tf.compat.v1.logging.info(\n          'Converting head to sequential head with '\n          '`SequentialHeadWrapper` to allow sequential predictions.')\n      head = seq_head_lib.SequentialHeadWrapper(\n          head,\n          sequence_length_mask=sequence_mask,\n          feature_columns=weight_column)\n\n    super(RNNClassifier, self).__init__(\n        head=head,\n        sequence_feature_columns=sequence_feature_columns,\n        context_feature_columns=context_feature_columns,\n        units=units,\n        cell_type=cell_type,\n        rnn_cell_fn=rnn_cell_fn,\n        return_sequences=return_sequences,\n        model_dir=model_dir,\n        optimizer=optimizer,\n        config=config)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/rnn_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for rnn.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport random\nimport tempfile\n\nfrom absl.testing import parameterized\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.core.example import example_pb2\nfrom tensorflow.core.example import feature_pb2\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator import model_fn\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned import parsing_utils\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.canned import rnn\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.head import multi_class_head as multi_head_lib\nfrom tensorflow_estimator.python.estimator.head import sequential_head as seq_head_lib\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\n\n# Names of variables created by BasicRNNCell model.\nCELL_KERNEL_NAME = 
'rnn_model/rnn/kernel'\nCELL_RECURRENT_KERNEL_NAME = 'rnn_model/rnn/recurrent_kernel'\nCELL_BIAS_NAME = 'rnn_model/rnn/bias'\nLOGITS_WEIGHTS_NAME = 'rnn_model/logits/kernel'\nLOGITS_BIAS_NAME = 'rnn_model/logits/bias'\n\n\ndef _assert_close(expected, actual, rtol=1e-04, name='assert_close'):\n  with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope:\n    expected = ops.convert_to_tensor(expected, name='expected')\n    actual = ops.convert_to_tensor(actual, name='actual')\n    rdiff = tf.math.abs(expected - actual, 'diff') / tf.math.abs(expected)\n    rtol = ops.convert_to_tensor(rtol, name='rtol')\n    return tf.compat.v1.debugging.assert_less(\n        rdiff,\n        rtol,\n        data=('Condition expected =~ actual did not hold element-wise:'\n              'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff,\n              'rtol = ', rtol,),\n        name=scope)\n\n\ndef create_checkpoint(kernel, recurrent, bias, dense_kernel, dense_bias,\n                      global_step, model_dir):\n  \"\"\"Create checkpoint file with provided model weights.\n\n  Args:\n    kernel: Iterable of values of input weights for the RNN cell.\n    recurrent: Iterable of values of recurrent weights for the RNN cell.\n    bias: Iterable of values of biases for the RNN cell.\n    dense_kernel: Iterable of values for matrix connecting RNN output to logits.\n    dense_bias: Iterable of values for logits bias term.\n    global_step: Initial global step to save in checkpoint.\n    model_dir: Directory into which checkpoint is saved.\n  \"\"\"\n  model_weights = {}\n  model_weights[CELL_KERNEL_NAME] = kernel\n  model_weights[CELL_RECURRENT_KERNEL_NAME] = recurrent\n  model_weights[CELL_BIAS_NAME] = bias\n  model_weights[LOGITS_WEIGHTS_NAME] = dense_kernel\n  model_weights[LOGITS_BIAS_NAME] = dense_bias\n\n  with tf.Graph().as_default():\n    # Create model variables.\n    for k, v in six.iteritems(model_weights):\n      tf.Variable(v, name=k, 
dtype=tf.dtypes.float32)\n\n    # Create non-model variables.\n    global_step_var = tf.compat.v1.train.create_global_step()\n    assign_op = global_step_var.assign(global_step)\n\n    # Initialize vars and save checkpoint.\n    with tf.compat.v1.train.MonitoredTrainingSession(\n        checkpoint_dir=model_dir) as sess:\n      sess.run(assign_op)\n\n\ndef _make_rnn_layer(rnn_cell_fn=None,\n                    units=None,\n                    cell_type=rnn.USE_DEFAULT,\n                    return_sequences=False):\n  return rnn._make_rnn_layer(\n      rnn_cell_fn=rnn_cell_fn,\n      units=units,\n      cell_type=cell_type,\n      return_sequences=return_sequences)\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass RNNLayerFnTest(tf.test.TestCase, parameterized.TestCase):\n  \"\"\"Tests for rnn layer function.\"\"\"\n\n  def testWrongClassProvided(self):\n    \"\"\"Tests that an error is raised if the class doesn't have a call method.\"\"\"\n    with self.assertRaisesRegexp(\n        ValueError, 'RNN cell should have a `call` and `state_size` method.'):\n      _make_rnn_layer(units=[10], cell_type=lambda units: object())\n\n  def testWrongStringProvided(self):\n    \"\"\"Tests that an error is raised if cell type is unknown.\"\"\"\n    with self.assertRaisesRegexp(\n        ValueError,\n        'cell_type` should be a class producing a RNN cell, or a string .*.'):\n      _make_rnn_layer(units=[10], cell_type='unknown-cell-name')\n\n  @parameterized.parameters(['simple_rnn', rnn.USE_DEFAULT])\n  def testDefaultCellProvided(self, cell_type):\n    \"\"\"Tests behavior when the default cell type is provided.\"\"\"\n    layer = _make_rnn_layer(cell_type=cell_type, units=[1])\n    self.assertIsInstance(layer, tf_keras.layers.RNN)\n    self.assertIsInstance(layer.cell, tf_keras.layers.SimpleRNNCell)\n\n  @parameterized.parameters([('gru', tf_keras.layers.GRU),\n                             ('lstm', tf_keras.layers.LSTM),\n                             ('simple_rnn', 
tf_keras.layers.SimpleRNN)])\n  def testSpecificLayerTypeProvided(self, cell_type, layer_type):\n    \"\"\"Tests specific layer type for GRU and LSTM.\"\"\"\n    layer = _make_rnn_layer(cell_type=cell_type, units=1)\n    self.assertIsInstance(layer, layer_type)\n\n  def testSpecificLayerTypeArguments(self):\n    \"\"\"Tests arguments for specific layer types (GRU and LSTM).\"\"\"\n    mock_layer_type = tf.compat.v1.test.mock.Mock()\n    with tf.compat.v1.test.mock.patch.object(rnn, '_CELL_TYPE_TO_LAYER_MAPPING',\n                                             {'custom-type': mock_layer_type}):\n      _make_rnn_layer(\n          cell_type='custom-type',\n          units=11,\n          return_sequences='return-seq-value')\n      mock_layer_type.assert_called_once_with(\n          units=11, return_sequences='return-seq-value')\n\n  @tf.compat.v1.test.mock.patch.object(tf_keras.layers, 'RNN')\n  def testCustomCellProvided(self, mock_rnn_layer_type):\n    \"\"\"Tests behavior when a custom cell type is provided.\"\"\"\n    mock_custom_cell = tf.compat.v1.test.mock.Mock()\n    _make_rnn_layer(\n        units=[10],\n        cell_type=lambda units: mock_custom_cell,\n        return_sequences='return-seq-value')\n    mock_rnn_layer_type.assert_called_once_with(\n        cell=mock_custom_cell, return_sequences='return-seq-value')\n\n  def testMultipleCellsProvided(self):\n    \"\"\"Tests behavior when multiple cells are provided.\"\"\"\n    layer = _make_rnn_layer(cell_type='simple_rnn', units=[1, 2])\n    self.assertIsInstance(layer, tf_keras.layers.RNN)\n    self.assertIsInstance(layer.cell, tf_keras.layers.StackedRNNCells)\n    self.assertLen(layer.cell.cells, 2)\n    self.assertIsInstance(layer.cell.cells[0], tf_keras.layers.SimpleRNNCell)\n\n  @tf.compat.v1.test.mock.patch.object(tf_keras.layers, 'RNN')\n  def testCustomCellFnProvided(self, mock_rnn_layer_type):\n    \"\"\"Tests behavior when a custom cell function is provided.\"\"\"\n    mock_cell_fn = 
tf.compat.v1.test.mock.Mock(return_value='custom-cell')\n    _make_rnn_layer(\n        rnn_cell_fn=mock_cell_fn, return_sequences='return-seq-value')\n    mock_rnn_layer_type.assert_called_once_with(\n        cell='custom-cell', return_sequences='return-seq-value')\n\n\ndef _mock_logits_layer(kernel, bias):\n  \"\"\"Sets initialization values to dense `logits` layers used in context.\"\"\"\n\n  class _MockDenseLayer(tf_keras.layers.Dense):\n\n    def __init__(self, units, activation, name):\n      kwargs = {}\n      if name == 'logits':\n        kwargs = {\n            'kernel_initializer': tf.compat.v1.initializers.constant(kernel),\n            'bias_initializer': tf.compat.v1.initializers.constant(bias)\n        }\n\n      super(_MockDenseLayer, self).__init__(\n          units=units, name=name, activation=activation, **kwargs)\n\n  return tf.compat.v1.test.mock.patch.object(tf_keras.layers, 'Dense',\n                                             _MockDenseLayer)\n\n\ndef _default_features_fn():\n  return {\n      'price':\n          tf.sparse.SparseTensor(\n              values=[10., 5.], indices=[[0, 0], [0, 1]], dense_shape=[1, 2]),\n  }\n\n\ndef _get_mock_head():\n  mock_head = multi_head_lib.MultiClassHead(3)\n  mock_head.create_estimator_spec = tf.compat.v1.test.mock.Mock(\n      return_value=model_fn.EstimatorSpec(None))\n  return mock_head\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass RNNLogitFnTest(tf.test.TestCase, parameterized.TestCase):\n  \"\"\"Tests correctness of logits calculated from RNNModel.\"\"\"\n\n  def setUp(self):\n    # Sets layers default weights for testing purpose.\n    self.kernel = [[.1, -.2]]\n    self.recurrent = [[.2, -.3], [.3, -.4]]\n    self.bias = [.2, .5]\n    self.dense_kernel = [[-1.], [1.]]\n    self.dense_bias = [0.3]\n    self.sequence_feature_columns = [\n        tf.feature_column.sequence_numeric_column('price', shape=(1,))\n    ]\n    self.context_feature_columns = []\n    super(RNNLogitFnTest, 
self).setUp()\n\n  def _mock_logits_layer(self):\n    return _mock_logits_layer(self.dense_kernel, bias=self.dense_bias)\n\n  def _test_logits(self,\n                   logits_dimension,\n                   features_fn,\n                   expected_logits,\n                   expected_mask,\n                   return_sequences=False,\n                   training=False):\n    \"\"\"Tests that the expected logits are calculated.\"\"\"\n    rnn_layer = tf_keras.layers.SimpleRNN(\n        2,\n        return_sequences=return_sequences,\n        kernel_initializer=tf.compat.v1.initializers.constant(self.kernel),\n        recurrent_initializer=tf.compat.v1.initializers.constant(\n            self.recurrent),\n        bias_initializer=tf.compat.v1.initializers.constant(self.bias))\n    with self._mock_logits_layer():\n      logit_layer = rnn.RNNModel(\n          rnn_layer=rnn_layer,\n          units=logits_dimension,\n          sequence_feature_columns=self.sequence_feature_columns,\n          context_feature_columns=self.context_feature_columns,\n          return_sequences=return_sequences)\n    logits = logit_layer(features_fn(), training=training)\n    if return_sequences:\n      logits = (logits, logits._keras_mask)\n      expected_logits = (expected_logits, expected_mask)\n    self.evaluate(tf.compat.v1.initializers.global_variables())\n    self.assertAllClose(expected_logits, self.evaluate(logits), atol=1e-4)\n\n  @parameterized.named_parameters(\n      {\n          'testcase_name': 'Static',\n          'return_sequences': False,\n          'expected_logits': [[-0.6033]]\n      }, {\n          'testcase_name': 'Sequential',\n          'return_sequences': True,\n          'expected_logits': [[[-1.4388], [-0.6033]]]\n      }, {\n          'testcase_name': 'SequentialTrain',\n          'return_sequences': True,\n          'expected_logits': [[[-1.4388], [-0.6033]]],\n          'training': True\n      }, {\n          'testcase_name': 'SequentialInfer',\n          
'return_sequences': True,\n          'expected_logits': [[[-1.4388], [-0.6033]]],\n          'training': False\n      })\n  def testOneDimLogits(self, return_sequences, expected_logits, training=False):\n    \"\"\"Tests one-dimensional logits.\n\n    Intermediate values are rounded for ease in reading.\n    input_layer = [[[10]], [[5]]]\n    sequence_mask = [[1, 1]]\n    initial_state = [0, 0]\n    rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2),\n                              tanh(-.2*10 - .3*0 - .4*0 +.5)]]\n                          = [[0.83, -0.91]]\n    rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2),\n                              tanh(-.2*5 - .3*.83 + .4*.91 +.5)]]\n                          = [[0.53, -0.37]]\n    logits_timestep_1 = [[-1*0.83 - 1*0.91 + 0.3]] = [[-1.4388]]\n    logits_timestep_2 = [[-1*0.53 - 1*0.37 + 0.3]] = [[-0.6033]]\n\n    Args:\n      return_sequences: A boolean indicating whether to return the last output\n        in the output sequence, or the full sequence.\n      expected_logits: An array with expected logits result.\n      training: Specifies if this training or evaluation / prediction mode.\n    \"\"\"\n    expected_mask = [[1, 1]]\n\n    self._test_logits(\n        logits_dimension=1,\n        features_fn=_default_features_fn,\n        expected_mask=expected_mask,\n        expected_logits=expected_logits,\n        return_sequences=return_sequences,\n        training=training)\n\n  @parameterized.named_parameters(\n      {\n          'testcase_name': 'Static',\n          'return_sequences': False,\n          'expected_logits': [[-0.6033, 0.7777, 0.5698]]\n      }, {\n          'testcase_name': 'Sequential',\n          'return_sequences': True,\n          'expected_logits': [[[-1.4388, 1.0884, 0.5762],\n                               [-0.6033, 0.7777, 0.5698]]]\n      })\n  def testMultiDimLogits(self, return_sequences, expected_logits):\n    \"\"\"Tests multi-dimensional logits.\n\n    Intermediate values are 
rounded for ease in reading.\n    input_layer = [[[10]], [[5]]]\n    sequence_mask = [[1, 1]]\n    initial_state = [0, 0]\n    rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2),\n                              tanh(-.2*10 - .3*0 - .4*0 +.5)]]\n                          = [[0.83, -0.91]]\n    rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2),\n                              tanh(-.2*5 - .3*.83 + .4*.91 +.5)]]\n                          = [[0.53, -0.37]]\n    logits_timestep_1 = [[-1*0.83 - 1*0.91 + 0.3],\n                         [0.5*0.83 + 0.3*0.91 + 0.4],\n                         [0.2*0.83 - 0.1*0.91 + 0.5]]\n                      = [[-1.4388, 1.0884, 0.5762]]\n    logits_timestep_2 = [[-1*0.53 - 1*0.37 + 0.3],\n                         [0.5*0.53 + 0.3*0.37 + 0.4],\n                         [0.2*0.53 - 0.1*0.37 + 0.5]]\n                      = [[-0.6033, 0.7777, 0.5698]]\n\n    Args:\n      return_sequences: A boolean indicating whether to return the last output\n        in the output sequence, or the full sequence.\n      expected_logits: An array with expected logits result.\n    \"\"\"\n    expected_mask = [[1, 1]]\n\n    self.dense_kernel = [[-1., 0.5, 0.2], [1., -0.3, 0.1]]\n    self.dense_bias = [0.3, 0.4, 0.5]\n    self._test_logits(\n        logits_dimension=3,\n        features_fn=_default_features_fn,\n        expected_mask=expected_mask,\n        expected_logits=expected_logits,\n        return_sequences=return_sequences)\n\n  @parameterized.named_parameters(\n      {\n          'testcase_name': 'Static',\n          'return_sequences': False,\n          'expected_logits': [[-0.6033, 0.7777, 0.5698],\n                              [-1.2473, 1.0170, 0.5745]]\n      }, {\n          'testcase_name': 'Sequential',\n          'return_sequences': True,\n          'expected_logits': [[\n              [-1.4388, 1.0884, 0.5762], [-0.6033, 0.7777, 0.5698]\n          ], [[0.0197, 0.5601, 0.5860], [-1.2473, 1.0170, 0.5745]]]\n      })\n  def 
testMultiExampleMultiDim(self, return_sequences, expected_logits):\n    \"\"\"Tests multiple examples and multi-dimensional logits.\n\n    Intermediate values are rounded for ease in reading.\n    input_layer = [[[10], [5]], [[2], [7]]]\n    sequence_mask = [[1, 1], [1, 1]]\n    initial_state = [[0, 0], [0, 0]]\n    rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2),\n                              tanh(-.2*10 - .3*0 - .4*0 +.5)],\n                             [tanh(.1*2 + .2*0 + .3*0 +.2),\n                              tanh(-.2*2 - .3*0 - .4*0 +.5)]]\n                          = [[0.83, -0.91], [0.38, 0.10]]\n    rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2),\n                              tanh(-.2*5 - .3*.83 + .4*.91 +.5)],\n                             [tanh(.1*7 + .2*.38 + .3*.10 +.2),\n                              tanh(-.2*7 - .3*.38 - .4*.10 +.5)]]\n                          = [[0.53, -0.37], [0.76, -0.78]\n    logits_timestep_1 = [[-1*0.83 - 1*0.91 + 0.3,\n                          0.5*0.83 + 0.3*0.91 + 0.4,\n                          0.2*0.83 - 0.1*0.91 + 0.5],\n                         [-1*0.38 + 1*0.10 + 0.3,\n                          0.5*0.38 - 0.3*0.10 + 0.4,\n                          0.2*0.38 + 0.1*0.10 + 0.5]]\n                      = [[-1.4388, 1.0884, 0.5762], [0.0197, 0.5601, 0.5860]]\n    logits_timestep_2 = [[-1*0.53 - 1*0.37 + 0.3,\n                          0.5*0.53 + 0.3*0.37 + 0.4,\n                          0.2*0.53 - 0.1*0.37 + 0.5],\n                         [-1*0.76 - 1*0.78 + 0.3,\n                          0.5*0.76 +0.3*0.78 + 0.4,\n                          0.2*0.76 -0.1*0.78 + 0.5]]\n                      = [[-0.6033, 0.7777, 0.5698], [-1.2473, 1.0170, 0.5745]]\n\n    Args:\n      return_sequences: A boolean indicating whether to return the last output\n        in the output sequence, or the full sequence.\n      expected_logits: An array with expected logits result.\n    \"\"\"\n    expected_mask = [[1, 1], [1, 
1]]\n\n    def features_fn():\n      return {\n          'price':\n              tf.sparse.SparseTensor(\n                  values=[10., 5., 2., 7.],\n                  indices=[[0, 0], [0, 1], [1, 0], [1, 1]],\n                  dense_shape=[2, 2]),\n      }\n\n    self.dense_kernel = [[-1., 0.5, 0.2], [1., -0.3, 0.1]]\n    self.dense_bias = [0.3, 0.4, 0.5]\n    self._test_logits(\n        logits_dimension=3,\n        features_fn=features_fn,\n        expected_mask=expected_mask,\n        expected_logits=expected_logits,\n        return_sequences=return_sequences)\n\n  @parameterized.named_parameters(\n      {\n          'testcase_name': 'Static',\n          'return_sequences': False,\n          'expected_logits': [[-0.6033], [0.0197]]\n      }, {\n          'testcase_name': 'Sequential',\n          'return_sequences': True,\n          'expected_logits': [[[-1.4388], [-0.6033]], [[0.0197], [0.0197]]]\n      })\n  def testMultiExamplesDifferentLength(self, return_sequences, expected_logits):\n    \"\"\"Tests multiple examples with different lengths.\n\n    Intermediate values are rounded for ease in reading.\n    input_layer = [[[10], [5]], [[2], [0]]]\n    sequence_mask = [[1, 1], [1, 0]]\n    initial_state = [[0, 0], [0, 0]]\n    rnn_output_timestep_1 = [[tanh(.1*10 + .2*0 + .3*0 +.2),\n                              tanh(-.2*10 - .3*0 - .4*0 +.5)],\n                             [tanh(.1*2 + .2*0 + .3*0 +.2),\n                              tanh(-.2*2 - .3*0 - .4*0 +.5)]]\n                          = [[0.83, -0.91], [0.38, 0.10]]\n    rnn_output_timestep_2 = [[tanh(.1*5 + .2*.83 - .3*.91 +.2),\n                              tanh(-.2*5 - .3*.83 + .4*.91 +.5)],\n                             [_]]\n                          = [[0.53, -0.37], [_, _]]\n    logits_timestep_1 = [[-1*0.83 - 1*0.91 + 0.3],\n                         [-1*0.38 + 1*0.10 + 0.3]]\n                      = [[-0.4388], [0.0197]]\n    logits_timestep_2 = [[-1*0.53 - 1*0.37 + 0.3],\n                    
     [_]]\n                      = [[-0.6033], [_]]\n\n    Args:\n      return_sequences: A boolean indicating whether to return the last output\n        in the output sequence, or the full sequence.\n      expected_logits: An array with expected logits result.\n    \"\"\"\n    expected_mask = [[1, 1], [1, 0]]\n\n    def features_fn():\n      return {\n          'price':\n              tf.sparse.SparseTensor(\n                  values=[10., 5., 2.],\n                  indices=[[0, 0], [0, 1], [1, 0]],\n                  dense_shape=[2, 2]),\n      }\n\n    self._test_logits(\n        logits_dimension=1,\n        features_fn=features_fn,\n        expected_mask=expected_mask,\n        expected_logits=expected_logits,\n        return_sequences=return_sequences)\n\n  def testMultiExamplesWithContext(self):\n    \"\"\"Tests multiple examples with context features.\n\n    Intermediate values are rounded for ease in reading.\n    input_layer = [[[10, -0.5], [5, -0.5]], [[2, 0.8], [0, 0]]]\n    sequence_mask = [[1, 1], [1, 0]]\n    initial_state = [[0, 0], [0, 0]]\n    rnn_output_timestep_1 = [[tanh(.1*10 - 1*.5 + .2*0 + .3*0 +.2),\n                              tanh(-.2*10 - 0.9*.5 - .3*0 - .4*0 +.5)],\n                             [tanh(.1*2 + 1*.8 + .2*0 + .3*0 +.2),\n                              tanh(-.2*2 + .9*.8 - .3*0 - .4*0 +.5)]]\n                          = [[0.60, -0.96], [0.83, 0.68]]\n    rnn_output_timestep_2 = [[tanh(.1*5 - 1*.5 + .2*.60 - .3*.96 +.2),\n                              tanh(-.2*5 - .9*.5 - .3*.60 + .4*.96 +.5)],\n                             [<ignored-padding>]]\n                          = [[0.03, -0.63], [<ignored-padding>]]\n    logits = [[-1*0.03 - 1*0.63 + 0.3],\n              [-1*0.83 + 1*0.68 + 0.3]]\n           = [[-0.3662], [0.1414]]\n    \"\"\"\n    expected_mask = [[1, 1], [1, 0]]\n\n    def features_fn():\n      return {\n          'price':\n              tf.sparse.SparseTensor(\n                  values=[10., 5., 2.],\n            
      indices=[[0, 0], [0, 1], [1, 0]],\n                  dense_shape=[2, 2]),\n          'context': [[-0.5], [0.8]],\n      }\n\n    self.context_feature_columns = [\n        tf.feature_column.numeric_column('context', shape=(1,))\n    ]\n\n    self.kernel = [[.1, -.2], [1., 0.9]]\n    self._test_logits(\n        logits_dimension=1,\n        features_fn=features_fn,\n        expected_mask=expected_mask,\n        expected_logits=[[-0.3662], [0.1414]])\n\n  def testMultiExamplesMultiFeatures(self):\n    \"\"\"Tests examples with multiple sequential feature columns.\n\n    Intermediate values are rounded for ease in reading.\n    input_layer = [[[1, 0, 10], [0, 1, 5]], [[1, 0, 2], [0, 0, 0]]]\n    sequence_mask = [[1, 1], [1, 0]]\n    initial_state = [[0, 0], [0, 0]]\n    rnn_output_timestep_1 = [[tanh(.5*1 + 1*0 + .1*10 + .2*0 + .3*0 +.2),\n                              tanh(-.5*1 - 1*0 - .2*10 - .3*0 - .4*0 +.5)],\n                             [tanh(.5*1 + 1*0 + .1*2 + .2*0 + .3*0 +.2),\n                              tanh(-.5*1 - 1*0 - .2*2 - .3*0 - .4*0 +.5)]]\n                          = [[0.94, -0.96], [0.72, -0.38]]\n    rnn_output_timestep_2 = [[tanh(.5*0 + 1*1 + .1*5 + .2*.94 - .3*.96 +.2),\n                              tanh(-.5*0 - 1*1 - .2*5 - .3*.94 + .4*.96 +.5)],\n                             [<ignored-padding>]]\n                          = [[0.92, -0.88], [<ignored-padding>]]\n    logits = [[-1*0.92 - 1*0.88 + 0.3],\n              [-1*0.72 - 1*0.38 + 0.3]]\n           = [[-1.5056], [-0.7962]]\n    \"\"\"\n    expected_mask = [[1, 1], [1, 0]]\n\n    def features_fn():\n      return {\n          'price':\n              tf.sparse.SparseTensor(\n                  values=[10., 5., 2.],\n                  indices=[[0, 0], [0, 1], [1, 0]],\n                  dense_shape=[2, 2]),\n          'on_sale':\n              tf.sparse.SparseTensor(\n                  values=[0, 1, 0],\n                  indices=[[0, 0], [0, 1], [1, 0]],\n                  
dense_shape=[2, 2]),\n      }\n\n    price_column = tf.feature_column.sequence_numeric_column(\n        'price', shape=(1,))\n    on_sale_column = tf.feature_column.indicator_column(\n        tf.feature_column.sequence_categorical_column_with_identity(\n            'on_sale', num_buckets=2))\n    self.sequence_feature_columns = [price_column, on_sale_column]\n\n    self.kernel = [[.5, -.5], [1., -1.], [.1, -.2]]\n    self._test_logits(\n        logits_dimension=1,\n        features_fn=features_fn,\n        expected_mask=expected_mask,\n        expected_logits=[[-1.5056], [-0.7962]])\n\n  @parameterized.parameters([(model_fn.ModeKeys.TRAIN, True),\n                             (model_fn.ModeKeys.EVAL, False),\n                             (model_fn.ModeKeys.PREDICT, False)])\n  def testTrainingMode(self, mode, expected_training_mode):\n    \"\"\"Tests that `training` argument is properly used.\"\"\"\n\n    class _MockRNNCell(tf_keras.layers.SimpleRNNCell):\n      \"\"\"Used to test that `training` argument is properly used.\"\"\"\n\n      def __init__(self, test_case):\n        self._test_case = test_case\n        super(_MockRNNCell, self).__init__(units=10)\n\n      def call(self, inputs, states, training=None):\n        self._test_case.assertEqual(training, expected_training_mode)\n        return super(_MockRNNCell, self).call(\n            inputs=inputs, states=states, training=training)\n\n    estimator = rnn.RNNEstimator(\n        head=_get_mock_head(),\n        rnn_cell_fn=lambda: _MockRNNCell(self),\n        sequence_feature_columns=self.sequence_feature_columns)\n    features = {\n        'price':\n            tf.sparse.SparseTensor(\n                values=[\n                    10.,\n                ], indices=[[0, 0]], dense_shape=[1, 1]),\n    }\n    estimator.model_fn(features=features, labels=None, mode=mode, config=None)\n\n\nclass RNNModelTest(tf.test.TestCase, parameterized.TestCase):\n  \"\"\"Tests for RNNModel.\"\"\"\n\n  def setUp(self):\n    
super(RNNModelTest, self).setUp()\n    self.kernel = [[.1, -.2]]\n    self.recurrent = [[.2, -.3], [.3, -.4]]\n    self.bias = [.2, .5]\n    self.dense_kernel = [[-1.], [1.]]\n    self.dense_bias = [0.3]\n    self.sequence_feature_columns = [\n        tf.feature_column.sequence_numeric_column('price', shape=(1,))\n    ]\n    self.x = {\n        'price':\n            tf.sparse.SparseTensor(\n                values=[10., 5., 2.],\n                indices=[[0, 0], [0, 1], [1, 0]],\n                dense_shape=[2, 2]),\n    }\n    self.y = ops.convert_to_tensor([[[0], [1]], [[0], [1]]])\n\n  def _get_compiled_model(self,\n                          return_sequences=False,\n                          optimizer='Adam',\n                          **kwargs):\n    \"\"\"Initializes and compiles a RNN model with specific weights.\"\"\"\n    rnn_layer = tf_keras.layers.SimpleRNN(\n        2,\n        return_sequences=return_sequences,\n        kernel_initializer=tf.compat.v1.initializers.constant(self.kernel),\n        recurrent_initializer=tf.compat.v1.initializers.constant(\n            self.recurrent),\n        bias_initializer=tf.compat.v1.initializers.constant(self.bias))\n    with _mock_logits_layer(self.dense_kernel, bias=self.dense_bias):\n      model = rnn.RNNModel(\n          units=1,\n          rnn_layer=rnn_layer,\n          sequence_feature_columns=self.sequence_feature_columns,\n          activation=tf_keras.activations.sigmoid,\n          return_sequences=return_sequences,\n          **kwargs)\n      model.compile(\n          optimizer=optimizer,\n          loss=tf_keras.losses.BinaryCrossentropy(reduction='sum'),\n          metrics=['accuracy'])\n    return model\n\n  def testModelWeights(self):\n    \"\"\"Tests that the layers weights are properly added to the model weights.\"\"\"\n    col = tf.feature_column.categorical_column_with_hash_bucket(\n        'tokens', hash_bucket_size=1)\n    context_feature_columns = [\n        
tf.feature_column.embedding_column(col, dimension=1)\n    ]\n    seq_col = tf.feature_column.sequence_categorical_column_with_hash_bucket(\n        'seq-tokens', hash_bucket_size=1)\n    sequence_feature_columns = [\n        tf.feature_column.embedding_column(seq_col, dimension=1)\n    ]\n    model = rnn.RNNModel(\n        units=1,\n        rnn_layer=tf_keras.layers.SimpleRNN(2),\n        sequence_feature_columns=sequence_feature_columns,\n        activation=tf_keras.activations.sigmoid,\n        context_feature_columns=context_feature_columns)\n    model.compile(\n        optimizer='Adam',\n        loss=tf_keras.losses.BinaryCrossentropy(reduction='sum'),\n        metrics=['accuracy'])\n\n    model.predict(\n        x={\n            'tokens': ops.convert_to_tensor([['a']]),\n            'seq-tokens': ops.convert_to_tensor([[['a']]])\n        },\n        steps=1)\n    # Weights included are:\n    # - recurrent, kernel and bias from RNN layer\n    # - kernel and bias from logits layer\n    # - sequential feature column embedding\n    # - context feature column embedding.\n    self.assertLen(model.get_weights(), 7)\n\n  def _testModelConfig(self, **kwargs):\n    \"\"\"Tests the parameters of a RNNModel stored to and restored from config.\n\n    Args:\n      **kwargs: Additional keyword arguments to initialize the RNNModel before\n        calling `get_config`.\n\n    Returns:\n      A dictionary with RNNModel initialization arguments from the `from_config`\n      call.\n    \"\"\"\n    seq_col = tf.feature_column.sequence_categorical_column_with_hash_bucket(\n        'seq-tokens', hash_bucket_size=1)\n    sequence_feature_columns = [\n        tf.feature_column.embedding_column(\n            seq_col, dimension=1, initializer=tf.compat.v1.initializers.zeros())\n    ]\n    model = rnn.RNNModel(\n        units=11,\n        rnn_layer=tf_keras.layers.SimpleRNN(3),\n        sequence_feature_columns=sequence_feature_columns,\n        return_sequences=True,\n        
name='rnn-model',\n        **kwargs)\n\n    with tf.compat.v1.test.mock.patch.object(\n        rnn.RNNModel, '__init__', return_value=None) as init:\n      rnn.RNNModel.from_config(\n          model.get_config(),\n          custom_objects={'Zeros': tf.compat.v1.initializers.zeros})\n      return list(init.call_args_list[0])[1]\n\n  def testModelConfig(self):\n    \"\"\"Tests that a RNNModel can be stored to and restored from config.\"\"\"\n    init_kwargs = self._testModelConfig()\n    self.assertEqual(init_kwargs['name'], 'rnn-model')\n    self.assertEqual(init_kwargs['units'], 11)\n    self.assertEqual(init_kwargs['return_sequences'], True)\n    self.assertEqual(\n        init_kwargs['sequence_feature_columns'][0].categorical_column.name,\n        'seq-tokens')\n    self.assertEqual(init_kwargs['context_feature_columns'], None)\n    self.assertEqual(init_kwargs['activation'].__name__, 'linear')\n    self.assertEqual(init_kwargs['rnn_layer'].cell.units, 3)\n\n  def testModelConfigWithActivation(self):\n    \"\"\"Tests store / restore from config with logits activation.\"\"\"\n    init_kwargs = self._testModelConfig(activation=tf_keras.activations.sigmoid)\n    self.assertEqual(init_kwargs['activation'].__name__, 'sigmoid')\n\n  def testModelConfigWithContextFeatures(self):\n    \"\"\"Tests store / restore from config with context features.\"\"\"\n    init_kwargs = self._testModelConfig(context_feature_columns=[\n        tf.feature_column.numeric_column('context', shape=(1,))\n    ])\n    self.assertEqual(init_kwargs['context_feature_columns'][0].name, 'context')\n\n  def DISABLED_testSaveModelWeights(self):  # See b/129842600.\n    \"\"\"Tests that model weights can be saved and restored.\"\"\"\n    model = self._get_compiled_model(return_sequences=True)\n    model.fit(x=self.x, y=self.y, batch_size=1, steps_per_epoch=1, epochs=1)\n    y1 = model.predict(x=self.x, steps=1)\n    model.save_weights(self.get_temp_dir() + 'model')\n\n    model = 
self._get_compiled_model(return_sequences=True, name='model-2')\n    model.load_weights(self.get_temp_dir() + 'model')\n    y2 = model.predict(x=self.x, steps=1)\n    self.assertAllClose(y1, y2)\n\n  def DISABLED_testEvaluationMetrics(self):  # See b/129842600.\n    \"\"\"Tests evaluation metrics computation in non-sequential case.\"\"\"\n    model = self._get_compiled_model()\n    metrics = model.evaluate(\n        x=self.x, y=ops.convert_to_tensor([[0], [1]]), steps=1)\n    # See `RNNClassifierEvaluationTest` for details on computation.\n    self.assertAllClose(metrics, (1.1196611, 1.), atol=1e-4)\n\n  def DISABLED_testEvaluationSequential(self):  # See b/129842600.\n    \"\"\"Tests that the sequence mask is properly used to aggregate loss.\"\"\"\n    model = self._get_compiled_model(return_sequences=True)\n    metrics = model.evaluate(x=self.x, y=self.y, steps=1)\n    # See `RNNClassifierEvaluationTest` for details on computation.\n    self.assertAllClose(metrics, (1.9556, 1. / 3.), atol=1e-4)\n\n  def DISABLED_testPredictions(self):  # See b/129842600.\n    \"\"\"Tests predictions with RNN model.\"\"\"\n    model = self._get_compiled_model()\n    # See `RNNClassifierPredictionTest` for details on computation.\n    self.assertAllClose(\n        model.predict(x=self.x, steps=1), [[0.353593], [0.5049296]], atol=1e-4)\n\n  def DISABLED_testPredictionsSequential(self):  # See b/129842600.\n    \"\"\"Tests sequential predictions with RNN model.\"\"\"\n    model = self._get_compiled_model(return_sequences=True)\n    # See `RNNClassifierPredictionTest` for details on computation.\n    self.assertAllClose(\n        model.predict(x=self.x, steps=1),\n        [[[0.191731], [0.353593]], [[0.5049296], [0.5049296]]],\n        atol=1e-4)\n\n  @parameterized.named_parameters(\n      ('StringOptimizer', 'Adam'),\n      ('OptimizerInstance', tf_keras.optimizers.Adam()))\n  def DISABLED_testTraining(self, optimizer):  # See b/129842600.\n    \"\"\"Tests the loss computed in 
training step.\"\"\"\n    model = self._get_compiled_model(optimizer=optimizer)\n    history = model.fit(\n        x=self.x,\n        y=ops.convert_to_tensor([[0], [1]]),\n        batch_size=1,\n        steps_per_epoch=1)\n    # See `RNNClassifierTrainingTest` for details on computation.\n    self.assertAllClose(history.history['loss'], [1.1196611], atol=1e-4)\n\n  def DISABLED_testTrainingSequential(self):  # See b/129842600.\n    \"\"\"Tests the loss computed in training step in sequential case.\"\"\"\n    model = self._get_compiled_model(return_sequences=True)\n    history = model.fit(x=self.x, y=self.y, batch_size=1, steps_per_epoch=1)\n    # See `RNNClassifierTrainingTest` for details on computation.\n    self.assertAllClose(history.history['loss'], [1.9556], atol=1e-4)\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass RNNEstimatorInitTest(tf.test.TestCase):\n\n  def setUp(self):\n    col = tf.feature_column.sequence_categorical_column_with_hash_bucket(\n        'tokens', hash_bucket_size=10)\n    self.feature_columns = [\n        tf.feature_column.embedding_column(col, dimension=2)\n    ]\n    self.cell_units = [4, 2]\n    super(RNNEstimatorInitTest, self).setUp()\n\n  def testConflictingRNNCellFn(self):\n    with self.assertRaisesRegexp(\n        ValueError,\n        'units and cell_type must not be specified when using rnn_cell_fn'):\n      rnn.RNNClassifier(\n          sequence_feature_columns=self.feature_columns,\n          rnn_cell_fn=lambda: 'mock-cell',\n          units=self.cell_units)\n\n    with self.assertRaisesRegexp(\n        ValueError,\n        'units and cell_type must not be specified when using rnn_cell_fn'):\n      rnn.RNNClassifier(\n          sequence_feature_columns=self.feature_columns,\n          rnn_cell_fn=lambda: 'mock-cell',\n          cell_type='lstm')\n\n  def testNonSequentialHeadProvided(self):\n    with self.assertRaisesRegexp(\n        ValueError, 'Provided head must be a `_SequentialHead` object when '\n        
'`return_sequences` is set to True.'):\n      rnn.RNNEstimator(\n          head=multi_head_lib.MultiClassHead(n_classes=3),\n          sequence_feature_columns=self.feature_columns,\n          return_sequences=True)\n\n  def testWrongOptimizerTypeProvided(self):\n    classifier = rnn.RNNClassifier(\n        self.feature_columns, units=[1], optimizer=object())\n    with self.assertRaisesRegexp(\n        ValueError,\n        'The given object is not a tf_keras.optimizers.Optimizer instance.'):\n      classifier.model_fn(\n          features=None, labels=None, mode=model_fn.ModeKeys.TRAIN, config=None)\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass RNNClassifierTrainingTest(tf.test.TestCase):\n\n  def setUp(self):\n    self.kernel = [[.1, -.2]]\n    self.recurrent = [[.2, -.3], [.3, -.4]]\n    self.bias = [.2, .5]\n    self.dense_kernel = [[-1.], [1.]]\n    self.dense_bias = [0.3]\n    self.sequence_feature_columns = [\n        tf.feature_column.sequence_numeric_column('price', shape=(1,))\n    ]\n    super(RNNClassifierTrainingTest, self).setUp()\n\n  def _assert_checkpoint(self, n_classes, input_units, cell_units,\n                         expected_global_step):\n\n    shapes = {\n        name: shape\n        for (name, shape) in tf.train.list_variables(self.get_temp_dir())\n    }\n\n    self.assertEqual([], shapes[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertEqual(\n        expected_global_step,\n        tf.train.load_variable(self.get_temp_dir(),\n                               tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n    # RNN Cell variables.\n    for i, cell_unit in enumerate(cell_units):\n      name_suffix = '_%d' % i if i else ''\n      self.assertEqual([input_units, cell_unit],\n                       shapes[CELL_KERNEL_NAME + name_suffix])\n      self.assertEqual([cell_unit, cell_unit],\n                       shapes[CELL_RECURRENT_KERNEL_NAME + name_suffix])\n      self.assertEqual([cell_unit], shapes[CELL_BIAS_NAME + name_suffix])\n      
input_units = cell_unit\n\n    # Logits variables.\n    logits_dimension = n_classes if n_classes > 2 else 1\n    self.assertEqual([cell_units[-1], logits_dimension],\n                     shapes[LOGITS_WEIGHTS_NAME])\n    self.assertEqual([logits_dimension], shapes[LOGITS_BIAS_NAME])\n\n  def _mock_optimizer(self, expected_loss=None):\n    var_names = (CELL_BIAS_NAME, CELL_KERNEL_NAME, CELL_RECURRENT_KERNEL_NAME,\n                 LOGITS_BIAS_NAME, LOGITS_WEIGHTS_NAME)\n    expected_var_names = ['%s:0' % name for name in var_names]\n\n    class _Optimizer(tf_keras.optimizers.Optimizer):\n      \"\"\"Mock optimizer checking that loss has the proper value.\"\"\"\n\n      def __init__(self, test_case):\n        super(_Optimizer, self).__init__(name='my-optimizer')\n        self.call_count = 0\n        self._test_case = test_case\n\n      def get_updates(self, loss, params):\n        self.call_count += 1\n        trainable_vars = tf.compat.v1.get_collection(\n            tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n        self._test_case.assertItemsEqual(expected_var_names,\n                                         [var.name for var in trainable_vars])\n\n        # Verify loss. 
We can't check the value directly so we add an assert op.\n        self._test_case.assertEquals(0, loss.shape.ndims)\n        if expected_loss is None:\n          return [self.iterations.assign_add(1).op]\n        assert_loss = _assert_close(\n            tf.cast(expected_loss, name='expected', dtype=tf.dtypes.float32),\n            loss,\n            name='assert_loss')\n        with tf.control_dependencies((assert_loss,)):\n          return [self.iterations.assign_add(1).op]\n\n      def get_config(self):\n        pass\n\n    return _Optimizer(test_case=self)\n\n  def _testFromScratchWithDefaultOptimizer(self, n_classes):\n\n    def train_input_fn():\n      return {\n          'tokens':\n              tf.sparse.SparseTensor(\n                  values=['the', 'cat', 'sat'],\n                  indices=[[0, 0], [0, 1], [0, 2]],\n                  dense_shape=[1, 3]),\n      }, [[1]]\n\n    col = tf.feature_column.sequence_categorical_column_with_hash_bucket(\n        'tokens', hash_bucket_size=10)\n    embed = tf.feature_column.embedding_column(col, dimension=2)\n    input_units = 2\n\n    cell_units = [4, 2]\n    est = rnn.RNNClassifier(\n        sequence_feature_columns=[embed],\n        units=cell_units,\n        n_classes=n_classes,\n        model_dir=self.get_temp_dir())\n\n    # Train for a few steps, and validate final checkpoint.\n    num_steps = 10\n    est.train(input_fn=train_input_fn, steps=num_steps)\n    self._assert_checkpoint(n_classes, input_units, cell_units, num_steps)\n\n  def testBinaryClassFromScratchWithDefaultOptimizer(self):\n    self._testFromScratchWithDefaultOptimizer(n_classes=2)\n\n  def testMultiClassFromScratchWithDefaultOptimizer(self):\n    self._testFromScratchWithDefaultOptimizer(n_classes=4)\n\n  def testFromScratchWithCustomRNNCellFn(self):\n\n    def train_input_fn():\n      return {\n          'tokens':\n              tf.sparse.SparseTensor(\n                  values=['the', 'cat', 'sat'],\n                  indices=[[0, 0], 
[0, 1], [0, 2]],\n                  dense_shape=[1, 3]),\n      }, [[1]]\n\n    col = tf.feature_column.sequence_categorical_column_with_hash_bucket(\n        'tokens', hash_bucket_size=10)\n    embed = tf.feature_column.embedding_column(col, dimension=2)\n    input_units = 2\n    cell_units = [4, 2]\n    n_classes = 2\n\n    def rnn_cell_fn():\n      cells = [tf_keras.layers.SimpleRNNCell(units=n) for n in cell_units]\n      return tf_keras.layers.StackedRNNCells(cells)\n\n    est = rnn.RNNClassifier(\n        sequence_feature_columns=[embed],\n        rnn_cell_fn=rnn_cell_fn,\n        n_classes=n_classes,\n        model_dir=self.get_temp_dir())\n\n    # Train for a few steps, and validate final checkpoint.\n    num_steps = 10\n    est.train(input_fn=train_input_fn, steps=num_steps)\n    self._assert_checkpoint(n_classes, input_units, cell_units, num_steps)\n\n  def _testExampleWeight(self, n_classes):\n\n    def train_input_fn():\n      return {\n          'tokens':\n              tf.sparse.SparseTensor(\n                  values=['the', 'cat', 'sat', 'dog', 'barked'],\n                  indices=[[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],\n                  dense_shape=[2, 3]),\n          'w': [[1], [2]],\n      }, [[1], [0]]\n\n    col = tf.feature_column.sequence_categorical_column_with_hash_bucket(\n        'tokens', hash_bucket_size=10)\n    embed = tf.feature_column.embedding_column(col, dimension=2)\n    input_units = 2\n\n    cell_units = [4, 2]\n    est = rnn.RNNClassifier(\n        units=cell_units,\n        sequence_feature_columns=[embed],\n        n_classes=n_classes,\n        weight_column='w',\n        model_dir=self.get_temp_dir())\n\n    # Train for a few steps, and validate final checkpoint.\n    num_steps = 10\n    est.train(input_fn=train_input_fn, steps=num_steps)\n    self._assert_checkpoint(n_classes, input_units, cell_units, num_steps)\n\n  def testBinaryClassWithExampleWeight(self):\n    self._testExampleWeight(n_classes=2)\n\n  def 
testMultiClassWithExampleWeight(self):\n    self._testExampleWeight(n_classes=4)\n\n  def _testFromCheckpoint(self, input_fn, expected_loss, **kwargs):\n    \"\"\"Loads classifier from checkpoint, runs training and checks loss.\"\"\"\n    create_checkpoint(\n        kernel=self.kernel,\n        recurrent=self.recurrent,\n        bias=self.bias,\n        dense_kernel=self.dense_kernel,\n        dense_bias=self.dense_bias,\n        global_step=100,\n        model_dir=self.get_temp_dir())\n\n    mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)\n\n    est = rnn.RNNClassifier(\n        units=[2],\n        sequence_feature_columns=self.sequence_feature_columns,\n        optimizer=mock_optimizer,\n        model_dir=self.get_temp_dir(),\n        **kwargs)\n    self.assertEqual(0, mock_optimizer.call_count)\n    est.train(input_fn=input_fn, steps=10)\n    self.assertEqual(1, mock_optimizer.call_count)\n\n  def testBinaryClassFromCheckpoint(self):\n\n    def train_input_fn():\n      return {\n          'price':\n              tf.sparse.SparseTensor(\n                  values=[10., 5., 2.],\n                  indices=[[0, 0], [0, 1], [1, 0]],\n                  dense_shape=[2, 2]),\n      }, [[0], [1]]\n\n    # Uses same checkpoint and examples as testBinaryClassEvaluationMetrics.\n    # See that test for loss calculation.\n    self._testFromCheckpoint(train_input_fn, expected_loss=0.559831)\n\n  def testMultiClassFromCheckpoint(self):\n\n    def train_input_fn():\n      return {\n          'price':\n              tf.sparse.SparseTensor(\n                  values=[10., 5., 2., 7.],\n                  indices=[[0, 0], [0, 1], [1, 0], [1, 1]],\n                  dense_shape=[2, 2]),\n      }, [[0], [1]]\n\n    # Uses same checkpoint and examples as testMultiClassEvaluationMetrics.\n    # See that test for loss calculation.\n    self.dense_kernel = [[-1., 0.5, 0.2], [1., -0.3, 0.1]]\n    self.dense_bias = [0.3, 0.4, 0.5]\n    self._testFromCheckpoint(\n        
train_input_fn, expected_loss=1.331465, n_classes=3)\n\n  def testBinaryClassFromCheckpointSequential(self):\n\n    def train_input_fn():\n      return {\n          'price':\n              tf.sparse.SparseTensor(\n                  values=[10., 5., 2.],\n                  indices=[[0, 0], [0, 1], [1, 0]],\n                  dense_shape=[2, 2]),\n      }, tf.sparse.SparseTensor(\n          values=[0, 1, 0],\n          indices=[[0, 0], [0, 1], [1, 0]],\n          dense_shape=[2, 2])\n\n    # Same example as testBinaryClassEvaluationMetricsSequential.\n    # logits = [[[-1.4388], [-0.6033]],\n    #            [[0.0197], [_]]]\n    # probability = np.exp(logits) / (1 + np.exp(logits))\n    #             = [[0.1917, 0.3536],\n    #                [0.5049, _]]\n    # loss = -label * ln(p) - (1 - label) * ln(1 - p)\n    # loss = [[0.2129,  1.0396],\n    #         [0.7031, _]]\n    # aggregated_loss = sum(loss) / 3\n    # aggregated_loss = 0.6518\n    self._testFromCheckpoint(\n        train_input_fn, expected_loss=0.651841, return_sequences=True)\n\n  def testBinaryClassFromCheckpointSequentialWithWeights(self):\n\n    def train_input_fn():\n      return {\n          'price':\n              tf.sparse.SparseTensor(\n                  values=[10., 5., 2.],\n                  indices=[[0, 0], [0, 1], [1, 0]],\n                  dense_shape=[2, 2]),\n          'weights':\n              tf.sparse.SparseTensor(\n                  values=[0., 0.5, 0.5],\n                  indices=[[0, 0], [0, 1], [1, 0]],\n                  dense_shape=[2, 2])\n      }, tf.sparse.SparseTensor(\n          values=[0, 0, 1],\n          indices=[[0, 0], [0, 1], [1, 0]],\n          dense_shape=[2, 2])\n\n    # Checkpoint and input are the same as testBinaryClassEvaluationMetrics, and\n    # expected loss is the same as we use non-zero weights only for the last\n    # step of each sequence.\n    # loss = [[_,  0.436326],\n    #         [0.6833351, _]]\n    # weights = [[0, 0.5], [0.5, 0]]\n    # 
aggregated_loss = (0.436326 + 0.6833351) / 2.\n    #                 = 0.559831\n    self._testFromCheckpoint(\n        train_input_fn,\n        expected_loss=0.559831,\n        return_sequences=True,\n        weight_column='weights',\n        loss_reduction=tf_keras.losses.Reduction.SUM)\n\n  def testDefaultGradientClipping(self):\n    \"\"\"Tests that optimizer applies default gradient clipping value.\"\"\"\n\n    def train_input_fn():\n      return {\n          'price':\n              tf.sparse.SparseTensor(\n                  values=[\n                      1.,\n                  ], indices=[[0, 0]], dense_shape=[1, 1]),\n      }, [[1]]\n\n    def _wrap_create_estimator_spec(create_estimator_spec):\n      \"\"\"Wraps function and asserts that the optimizer applies clipping.\"\"\"\n\n      def _wrapped_create_estimator_spec(obj,\n                                         features,\n                                         mode,\n                                         logits,\n                                         labels=None,\n                                         optimizer=None,\n                                         trainable_variables=None,\n                                         train_op_fn=None,\n                                         update_ops=None,\n                                         regularization_losses=None):\n        var = tf.Variable([1.0])\n        mock_loss = 10 * var\n        gradients = optimizer.get_gradients(mock_loss, [var])\n        self.assertLen(gradients, 1)\n        # Initial gradient value is 10 and expected to be clipped to 5 (default\n        # clipping value).\n        with tf.control_dependencies(\n            (tf.compat.v1.debugging.assert_equal(gradients[0], 5.0),)):\n          return create_estimator_spec(obj, features, mode, logits, labels,\n                                       optimizer, trainable_variables,\n                                       train_op_fn, update_ops,\n                                  
     regularization_losses)\n\n      return _wrapped_create_estimator_spec\n\n    with tf.compat.v1.test.mock.patch.object(\n        multi_head_lib.MultiClassHead, 'create_estimator_spec',\n        _wrap_create_estimator_spec(\n            multi_head_lib.MultiClassHead.create_estimator_spec)):\n      est = rnn.RNNClassifier(\n          n_classes=3,\n          sequence_feature_columns=[\n              tf.feature_column.sequence_numeric_column('price')\n          ],\n          units=[2],\n          model_dir=self.get_temp_dir())\n      est.train(input_fn=train_input_fn, steps=1)\n\n\ndef sorted_key_dict(unsorted_dict):\n  return {k: unsorted_dict[k] for k in sorted(unsorted_dict)}\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass RNNClassifierEvaluationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self.kernel = [[.1, -.2]]\n    self.recurrent = [[.2, -.3], [.3, -.4]]\n    self.bias = [.2, .5]\n    self.dense_kernel = [[-1.], [1.]]\n    self.dense_bias = [0.3]\n    self.global_step = 100\n    self.sequence_feature_columns = [\n        tf.feature_column.sequence_numeric_column('price', shape=(1,))\n    ]\n    super(RNNClassifierEvaluationTest, self).setUp()\n\n  def _testFromCheckpoint(self, input_fn, **kwargs):\n    create_checkpoint(\n        kernel=self.kernel,\n        recurrent=self.recurrent,\n        bias=self.bias,\n        dense_kernel=self.dense_kernel,\n        dense_bias=self.dense_bias,\n        global_step=self.global_step,\n        model_dir=self.get_temp_dir())\n\n    est = rnn.RNNClassifier(\n        units=[2],\n        sequence_feature_columns=self.sequence_feature_columns,\n        model_dir=self.get_temp_dir(),\n        **kwargs)\n    return est.evaluate(input_fn, steps=1)\n\n  def testBinaryClassEvaluationMetrics(self):\n\n    def eval_input_fn():\n      return {\n          'price':\n              tf.sparse.SparseTensor(\n                  values=[10., 5., 2.],\n                  indices=[[0, 0], [0, 1], [1, 0]],\n                  
dense_shape=[2, 2]),\n      }, [[0], [1]]\n\n    eval_metrics = self._testFromCheckpoint(eval_input_fn)\n\n    # Uses identical numbers to testMultiExamplesWithDifferentLength.\n    # See that test for logits calculation.\n    # logits = [[-0.603282], [0.019719]]\n    # probability = exp(logits) / (1 + exp(logits)) = [[0.353593], [0.504930]]\n    # loss = -label * ln(p) - (1 - label) * ln(1 - p)\n    #      = [[0.436326], [0.683335]]\n    # sum_over_batch_size = (0.436326 + 0.683335)/2\n    expected_metrics = {\n        tf.compat.v1.GraphKeys.GLOBAL_STEP: self.global_step,\n        metric_keys.MetricKeys.LOSS: 0.559831,\n        metric_keys.MetricKeys.LOSS_MEAN: 0.559831,\n        metric_keys.MetricKeys.ACCURACY: 1.0,\n        metric_keys.MetricKeys.PREDICTION_MEAN: 0.429262,\n        metric_keys.MetricKeys.LABEL_MEAN: 0.5,\n        metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,\n        # With default threshold of 0.5, the model is a perfect classifier.\n        metric_keys.MetricKeys.RECALL: 1.0,\n        metric_keys.MetricKeys.PRECISION: 1.0,\n        # Positive example is scored above negative, so AUC = 1.0.\n        metric_keys.MetricKeys.AUC: 1.0,\n        metric_keys.MetricKeys.AUC_PR: 1.0,\n    }\n    self.assertAllClose(\n        sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics))\n\n  def testBinaryClassEvaluationMetricsSequential(self):\n\n    def eval_input_fn():\n      return {\n          'price':\n              tf.sparse.SparseTensor(\n                  values=[10., 5., 2.],\n                  indices=[[0, 0], [0, 1], [1, 0]],\n                  dense_shape=[2, 2]),\n      }, tf.sparse.SparseTensor(\n          values=[0, 1, 0],\n          indices=[[0, 0], [0, 1], [1, 0]],\n          dense_shape=[2, 2])\n\n    eval_metrics = self._testFromCheckpoint(\n        eval_input_fn, return_sequences=True)\n\n    # logits = [[[-1.4388], [-0.6033]],\n    #            [[0.0197], [_]]]\n    # probability = np.exp(logits) / (1 + np.exp(logits))\n    #  
           = [[0.1917, 0.3536],\n    #                [0.5049, _]]\n    # labels = [[0, 1],\n    #           [0, _]]\n    # loss = -label * ln(p) - (1 - label) * ln(1 - p)\n    # loss = [[0.2129,  1.0396],\n    #         [0.7031, _]]\n    # aggregated_loss = sum(loss) / 3\n    # aggregated_loss = 0.6518\n    # accuracy = 1/3\n    # prediction_mean = mean(probability) = 0.3501\n    expected_metrics = {\n        tf.compat.v1.GraphKeys.GLOBAL_STEP: self.global_step,\n        metric_keys.MetricKeys.LOSS: 0.651841,\n        metric_keys.MetricKeys.LOSS_MEAN: 0.651841,\n        metric_keys.MetricKeys.ACCURACY: 1.0 / 3,\n        metric_keys.MetricKeys.PREDICTION_MEAN: 0.350085,\n        metric_keys.MetricKeys.LABEL_MEAN: 1.0 / 3,\n        metric_keys.MetricKeys.ACCURACY_BASELINE: 2.0 / 3,\n        metric_keys.MetricKeys.RECALL: 0.0,\n        metric_keys.MetricKeys.PRECISION: 0.0,\n        metric_keys.MetricKeys.AUC: 0.5,\n        metric_keys.MetricKeys.AUC_PR: 0.30685282,\n    }\n    self.assertAllClose(\n        sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics))\n\n  def testMultiClassEvaluationMetrics(self):\n\n    def eval_input_fn():\n      return {\n          'price':\n              tf.sparse.SparseTensor(\n                  values=[10., 5., 2., 7.],\n                  indices=[[0, 0], [0, 1], [1, 0], [1, 1]],\n                  dense_shape=[2, 2]),\n      }, [[0], [1]]\n\n    self.dense_kernel = [[-1., 0.5, 0.2], [1., -0.3, 0.1]]\n    self.dense_bias = [0.3, 0.4, 0.5]\n    # Uses identical numbers to testMultiExampleMultiDim.\n    # See that test for logits calculation.\n    # logits = [[-0.603282, 0.777708, 0.569756],\n    #           [-1.247356, 1.017018, 0.574481]]\n    # logits_exp = exp(logits) / (1 + exp(logits))\n    #            = [[0.547013, 2.176468, 1.767836],\n    #               [0.287263, 2.764937, 1.776208]]\n    # softmax_probabilities = logits_exp / logits_exp.sum()\n    #                       = [[0.121793, 0.484596, 0.393611],\n    # 
                         [0.059494, 0.572639, 0.367866]]\n    # loss = -1. * log(softmax[label])\n    #      = [[2.105432], [0.557500]]\n    # sum_over_batch_size = (2.105432 + 0.557500)/2\n    eval_metrics = self._testFromCheckpoint(eval_input_fn, n_classes=3)\n\n    expected_metrics = {\n        tf.compat.v1.GraphKeys.GLOBAL_STEP: self.global_step,\n        metric_keys.MetricKeys.LOSS: 1.331465,\n        metric_keys.MetricKeys.LOSS_MEAN: 1.331466,\n        metric_keys.MetricKeys.ACCURACY: 0.5,\n    }\n\n    self.assertAllClose(\n        sorted_key_dict(expected_metrics), sorted_key_dict(eval_metrics))\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass RNNClassifierPredictionTest(tf.test.TestCase):\n\n  def setUp(self):\n    self.kernel = [[.1, -.2]]\n    self.recurrent = [[.2, -.3], [.3, -.4]]\n    self.bias = [.2, .5]\n    self.dense_kernel = [[-1.], [1.]]\n    self.dense_bias = [0.3]\n    self.sequence_feature_columns = [\n        tf.feature_column.sequence_numeric_column('price', shape=(1,))\n    ]\n    super(RNNClassifierPredictionTest, self).setUp()\n\n  def _testFromCheckpoint(self, input_fn, **kwargs):\n    create_checkpoint(\n        kernel=self.kernel,\n        recurrent=self.recurrent,\n        bias=self.bias,\n        dense_kernel=self.dense_kernel,\n        dense_bias=self.dense_bias,\n        global_step=100,\n        model_dir=self.get_temp_dir())\n\n    n_classes = 2\n    if 'n_classes' in kwargs:\n      n_classes = kwargs['n_classes']\n      assert n_classes >= 2\n    label_vocabulary = [\n        'class_{}'.format(class_idx) for class_idx in range(n_classes)\n    ]\n\n    est = rnn.RNNClassifier(\n        units=[2],\n        sequence_feature_columns=self.sequence_feature_columns,\n        label_vocabulary=label_vocabulary,\n        model_dir=self.get_temp_dir(),\n        **kwargs)\n    return next(est.predict(input_fn))\n\n  def testBinaryClassPredictions(self):\n    # Uses identical numbers to testOneDimLogits.\n    # See that test for 
logits calculation.\n    # logits = [-0.603282]\n    # logistic = exp(-0.6033) / (1 + exp(-0.6033)) = [0.353593]\n    # probabilities = [0.646407, 0.353593]\n    # class_ids = argmax(probabilities) = [0]\n    predictions = self._testFromCheckpoint(_default_features_fn)\n    self.assertAllClose([-0.603282],\n                        predictions[prediction_keys.PredictionKeys.LOGITS])\n    self.assertAllClose([0.353593],\n                        predictions[prediction_keys.PredictionKeys.LOGISTIC])\n    self.assertAllClose(\n        [0.646407, 0.353593],\n        predictions[prediction_keys.PredictionKeys.PROBABILITIES])\n    self.assertAllClose([0],\n                        predictions[prediction_keys.PredictionKeys.CLASS_IDS])\n    self.assertEqual([b'class_0'],\n                     predictions[prediction_keys.PredictionKeys.CLASSES])\n\n  def testMultiClassPredictions(self):\n    self.dense_kernel = [[-1., 0.5, 0.2], [1., -0.3, 0.1]]\n    self.dense_bias = [0.3, 0.4, 0.5]\n    # Uses identical numbers to testMultiDimLogits.\n    # See that test for logits calculation.\n    # logits = [-0.603282, 0.777708, 0.569756]\n    # logits_exp = exp(logits) = [0.547013, 2.176468, 1.767836]\n    # softmax_probabilities = logits_exp / logits_exp.sum()\n    #                       = [0.121793, 0.484596, 0.393611]\n    # class_ids = argmax(probabilities) = [1]\n    predictions = self._testFromCheckpoint(_default_features_fn, n_classes=3)\n    self.assertAllClose([-0.603282, 0.777708, 0.569756],\n                        predictions[prediction_keys.PredictionKeys.LOGITS])\n    self.assertAllClose(\n        [0.121793, 0.484596, 0.393611],\n        predictions[prediction_keys.PredictionKeys.PROBABILITIES])\n    self.assertAllClose([1],\n                        predictions[prediction_keys.PredictionKeys.CLASS_IDS])\n    self.assertEqual([b'class_1'],\n                     predictions[prediction_keys.PredictionKeys.CLASSES])\n\n  def testBinaryClassPredictionsSequential(self):\n\n    
def predict_input_fn():\n      return {\n          'price':\n              tf.sparse.SparseTensor(\n                  values=[10., 5.],\n                  indices=[[0, 0], [0, 1]],\n                  dense_shape=[1, 2]),\n      }\n\n    # Same as first record of testBinaryClassEvaluationMetricsSequential.\n    # Last step values are carried over.\n    # logits = [[-1.4388], [-0.6033], [_]]\n    # probabilities = np.exp(logits) / (1 + np.exp(logits))\n    #               = [[0.8083, 0.1917], [0.6464, 0.3536], [_, _]]\n    # class_ids = [[0], [0], [_]]\n    # classes = [['class_0'], ['class_0'], [_]]\n    predictions = self._testFromCheckpoint(\n        predict_input_fn, return_sequences=True, sequence_mask='my-mask')\n    self.assertAllEqual([1, 1], predictions['my-mask'])\n    self.assertAllClose([[-1.438803], [-0.603282]],\n                        predictions[prediction_keys.PredictionKeys.LOGITS])\n    self.assertAllClose([[0.191731], [0.353593]],\n                        predictions[prediction_keys.PredictionKeys.LOGISTIC])\n    self.assertAllClose(\n        [[0.808269, 0.191731], [0.646407, 0.353593]],\n        predictions[prediction_keys.PredictionKeys.PROBABILITIES])\n    self.assertAllClose([[0], [0]],\n                        predictions[prediction_keys.PredictionKeys.CLASS_IDS])\n    self.assertAllEqual([[b'class_0'], [b'class_0']],\n                        predictions[prediction_keys.PredictionKeys.CLASSES])\n\n\nclass BaseRNNClassificationIntegrationTest(object):\n\n  def setUp(self):\n    col = tf.feature_column.sequence_categorical_column_with_hash_bucket(\n        'tokens', hash_bucket_size=10)\n    embed = tf.feature_column.embedding_column(col, dimension=2)\n    self.feature_columns = [embed]\n    super(BaseRNNClassificationIntegrationTest, self).setUp()\n\n  def __init__(self, _create_estimator_fn):\n    self._create_estimator_fn = _create_estimator_fn\n\n  def _test_complete_flow(self,\n                          train_input_fn,\n                   
       eval_input_fn,\n                          predict_input_fn,\n                          n_classes,\n                          batch_size,\n                          optimizer='Adam'):\n    cell_units = [4, 2]\n    est = self._create_estimator_fn(\n        self.feature_columns,\n        n_classes,\n        cell_units,\n        self.get_temp_dir(),\n        optimizer=optimizer)\n\n    # TRAIN\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # EVALUATE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # PREDICT\n    predicted_proba = np.array([\n        x[prediction_keys.PredictionKeys.PROBABILITIES]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)\n\n    # EXPORT\n    feature_spec = parsing_utils.classifier_parse_example_spec(\n        self.feature_columns, label_key='label', label_dtype=tf.dtypes.int64)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_savedmodel(tempfile.mkdtemp(),\n                                       serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def _testNumpyInputFn(self, optimizer):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    n_classes = 3\n    batch_size = 10\n    words = ['dog', 'cat', 'bird', 'the', 'a', 'sat', 'flew', 'slept']\n    # Numpy only supports dense input, so all examples will have same length.\n    # TODO(b/73160931): Update test when support for prepadded data exists.\n    sequence_length = 3\n\n    features = []\n    for _ in range(batch_size):\n      sentence = random.sample(words, sequence_length)\n      features.append(sentence)\n\n    x_data = np.array(features)\n    y_data = np.random.randint(n_classes, size=batch_size)\n\n    train_input_fn = 
numpy_io.numpy_input_fn(\n        x={'tokens': x_data},\n        y=y_data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'tokens': x_data}, y=y_data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'tokens': x_data}, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        n_classes=n_classes,\n        batch_size=batch_size,\n        optimizer=optimizer)\n\n  def testNumpyInputFnStringOptimizer(self):\n    self._testNumpyInputFn(optimizer='Adam')\n\n  def testNumpyInputFnOptimizerInstance(self):\n    self._testNumpyInputFn(optimizer=tf_keras.optimizers.Adam())\n\n  def testParseExampleInputFn(self):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    n_classes = 3\n    batch_size = 10\n    words = [b'dog', b'cat', b'bird', b'the', b'a', b'sat', b'flew', b'slept']\n\n    _, examples_file = tempfile.mkstemp()\n    writer = tf.io.TFRecordWriter(examples_file)\n    for _ in range(batch_size):\n      sequence_length = random.randint(1, len(words))\n      sentence = random.sample(words, sequence_length)\n      label = random.randint(0, n_classes - 1)\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'tokens':\n                      feature_pb2.Feature(\n                          bytes_list=feature_pb2.BytesList(value=sentence)),\n                  'label':\n                      feature_pb2.Feature(\n                          int64_list=feature_pb2.Int64List(value=[label])),\n              }))\n      writer.write(example.SerializeToString())\n    writer.close()\n\n    feature_spec = parsing_utils.classifier_parse_example_spec(\n        self.feature_columns, label_key='label', 
label_dtype=tf.dtypes.int64)\n\n    def _train_input_fn():\n      dataset = tf.compat.v1.data.experimental.make_batched_features_dataset(\n          examples_file, batch_size, feature_spec)\n      return dataset.map(lambda features: (features, features.pop('label')))\n\n    def _eval_input_fn():\n      dataset = tf.compat.v1.data.experimental.make_batched_features_dataset(\n          examples_file, batch_size, feature_spec, num_epochs=1)\n      return dataset.map(lambda features: (features, features.pop('label')))\n\n    def _predict_input_fn():\n      dataset = tf.compat.v1.data.experimental.make_batched_features_dataset(\n          examples_file, batch_size, feature_spec, num_epochs=1)\n\n      def features_fn(features):\n        features.pop('label')\n        return features\n\n      return dataset.map(features_fn)\n\n    self._test_complete_flow(\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        n_classes=n_classes,\n        batch_size=batch_size)\n\n\ndef _rnn_classifier_fn(feature_columns, n_classes, cell_units, model_dir,\n                       optimizer):\n  return rnn.RNNClassifier(\n      units=cell_units,\n      sequence_feature_columns=feature_columns,\n      n_classes=n_classes,\n      optimizer=optimizer,\n      model_dir=model_dir)\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass RNNClassifierIntegrationTest(BaseRNNClassificationIntegrationTest,\n                                   tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    BaseRNNClassificationIntegrationTest.__init__(self, _rnn_classifier_fn)\n\n\ndef _rnn_classifier_dropout_fn(feature_columns, n_classes, cell_units,\n                               model_dir, optimizer):\n\n  def _rnn_cell_fn():\n    cells = []\n    for units in cell_units:\n      cells.append(tf_keras.layers.SimpleRNNCell(units, 
dropout=0.5))\n    return tf_keras.layers.StackedRNNCells(cells)\n\n  return rnn.RNNClassifier(\n      rnn_cell_fn=_rnn_cell_fn,\n      sequence_feature_columns=feature_columns,\n      n_classes=n_classes,\n      optimizer=optimizer,\n      model_dir=model_dir)\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass RNNClassifierDropoutIntegrationTest(BaseRNNClassificationIntegrationTest,\n                                          tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    BaseRNNClassificationIntegrationTest.__init__(self,\n                                                  _rnn_classifier_dropout_fn)\n\n\ndef _rnn_estimator_fn(feature_columns, n_classes, cell_units, model_dir,\n                      optimizer):\n  return rnn.RNNEstimator(\n      head=multi_head_lib.MultiClassHead(n_classes=n_classes),\n      units=cell_units,\n      sequence_feature_columns=feature_columns,\n      optimizer=optimizer,\n      model_dir=model_dir)\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass RNNEstimatorIntegrationTest(BaseRNNClassificationIntegrationTest,\n                                  tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    BaseRNNClassificationIntegrationTest.__init__(self, _rnn_estimator_fn)\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass ModelFnTest(tf.test.TestCase):\n  \"\"\"Tests correctness of RNNEstimator's model function.\"\"\"\n\n  def _test_sequential_mask_in_head(self, mask=None):\n    features = {\n        'price':\n            tf.sparse.SparseTensor(\n                values=[10., 5., 4.],\n                indices=[[0, 0], [0, 1], [1, 0]],\n                dense_shape=[2, 2])\n    }\n    if mask:\n      features['sequence_mask'] = ops.convert_to_tensor(mask)\n    expected_mask = mask or [[1, 1], [1, 0]]\n\n    
sequence_feature_columns = [\n        tf.feature_column.sequence_numeric_column('price', shape=(1,))\n    ]\n\n    mock_head = _get_mock_head()\n    seq_head = seq_head_lib.SequentialHeadWrapper(\n        mock_head, sequence_length_mask='sequence_mask')\n    estimator = rnn.RNNEstimator(\n        head=seq_head,\n        units=[10],\n        sequence_feature_columns=sequence_feature_columns,\n        return_sequences=True)\n    estimator.model_fn(\n        features=features,\n        labels=None,\n        mode=model_fn.ModeKeys.PREDICT,\n        config=None)\n    passed_features = list(\n        mock_head.create_estimator_spec.call_args)[1]['features']\n    self.assertIn('sequence_mask', passed_features)\n    sequence_mask = self.evaluate(passed_features['sequence_mask'])\n    self.assertAllEqual(sequence_mask, expected_mask)\n\n  def testSequentialMaskInHead(self):\n    self._test_sequential_mask_in_head()\n\n  def testSequentialMaskInHeadWithMasks(self):\n    self._test_sequential_mask_in_head([[1, 1], [1, 1]])\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/saved_model_estimator.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Class that creates an Estimator from a SavedModel.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.saved_model import constants\nfrom tensorflow.python.saved_model import loader_impl\nfrom tensorflow.python.saved_model import path_helpers\nfrom tensorflow.python.saved_model import signature_constants\nfrom tensorflow_estimator.python.estimator import estimator as estimator_lib\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.export import export_lib\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n\nclass SavedModelEstimator(estimator_lib.EstimatorV2):\n  \"\"\"Create an Estimator from a SavedModel.\n\n  Only SavedModels exported with\n  `tf.estimator.Estimator.experimental_export_all_saved_models()` or\n  `tf.estimator.Estimator.export_saved_model()` are supported for this class.\n\n  Example with `tf.estimator.DNNClassifier`:\n\n  **Step 1: Create and train DNNClassifier.**\n\n  ```python\n  feature1 = tf.feature_column.embedding_column(\n      tf.feature_column.categorical_column_with_vocabulary_list(\n       
   key='feature1', vocabulary_list=('green', 'yellow')), dimension=1)\n  feature2 = tf.feature_column.numeric_column(key='feature2', default_value=0.0)\n\n  classifier = tf.estimator.DNNClassifier(\n      hidden_units=[4,2], feature_columns=[feature1, feature2])\n\n  def input_fn():\n    features = {'feature1': tf.constant(['green', 'green', 'yellow']),\n                'feature2': tf.constant([3.5, 4.2, 6.1])}\n    label = tf.constant([1., 0., 0.])\n    return tf.data.Dataset.from_tensors((features, label)).repeat()\n\n  classifier.train(input_fn=input_fn, steps=10)\n  ```\n\n  **Step 2: Export classifier.**\n  First, build functions that specify the expected inputs.\n\n  ```python\n  # During train and evaluation, both the features and labels should be defined.\n  supervised_input_receiver_fn = (\n      tf.estimator.experimental.build_raw_supervised_input_receiver_fn(\n          {'feature1': tf.placeholder(dtype=tf.string, shape=[None]),\n           'feature2': tf.placeholder(dtype=tf.float32, shape=[None])},\n          tf.placeholder(dtype=tf.float32, shape=[None])))\n\n  # During predict mode, expect to receive a `tf.Example` proto, so a parsing\n  # function is used.\n  serving_input_receiver_fn = (\n      tf.estimator.export.build_parsing_serving_input_receiver_fn(\n          tf.feature_column.make_parse_example_spec([feature1, feature2])))\n  ```\n\n  Next, export the model as a SavedModel. 
A timestamped directory will be\n  created (for example `/tmp/export_all/1234567890`).\n\n  ```python\n  # Option 1: Save all modes (train, eval, predict)\n  export_dir = classifier.experimental_export_all_saved_models(\n      '/tmp/export_all',\n      {tf.estimator.ModeKeys.TRAIN: supervised_input_receiver_fn,\n       tf.estimator.ModeKeys.EVAL: supervised_input_receiver_fn,\n       tf.estimator.ModeKeys.PREDICT: serving_input_receiver_fn})\n\n  # Option 2: Only export predict mode\n  export_dir = classifier.export_saved_model(\n      '/tmp/export_predict', serving_input_receiver_fn)\n  ```\n\n  **Step 3: Create a SavedModelEstimator from the exported SavedModel.**\n\n  ```python\n  est = tf.estimator.experimental.SavedModelEstimator(export_dir)\n\n  # If all modes were exported, you can immediately evaluate and predict, or\n  # continue training. Otherwise only predict is available.\n  eval_results = est.evaluate(input_fn=input_fn, steps=1)\n  print(eval_results)\n\n  est.train(input_fn=input_fn, steps=20)\n\n  def predict_input_fn():\n    example = tf.train.Example()\n    example.features.feature['feature1'].bytes_list.value.extend(['yellow'])\n    example.features.feature['feature2'].float_list.value.extend([1.])\n    return {'inputs':tf.constant([example.SerializeToString()])}\n\n  predictions = est.predict(predict_input_fn)\n  print(next(predictions))\n  ```\n  \"\"\"\n\n  def __init__(self, saved_model_dir, model_dir=None):\n    \"\"\"Initialize a SavedModelEstimator.\n\n    The SavedModelEstimator loads its model function and variable values from\n    the graphs defined in the SavedModel. 
There is no option to pass in\n    `RunConfig` or `params` arguments, because the model function graph is\n    defined statically in the SavedModel.\n\n    Args:\n      saved_model_dir: Directory containing SavedModel protobuf and subfolders.\n      model_dir: Directory to save new checkpoints during training.\n\n    Raises:\n      NotImplementedError: If a DistributionStrategy is defined in the config.\n        Unless the SavedModelEstimator is subclassed, this shouldn't happen.\n    \"\"\"\n\n    super(SavedModelEstimator, self).__init__(\n        model_fn=self._model_fn_from_saved_model, model_dir=model_dir)\n    if self._train_distribution or self._eval_distribution:\n      raise NotImplementedError(\n          'SavedModelEstimator currently does not support '\n          'DistributionStrategy.')\n    self.saved_model_dir = saved_model_dir\n    self.saved_model_loader = loader_impl.SavedModelLoader(saved_model_dir)\n    self._available_modes = self._extract_available_modes()\n\n  def _extract_available_modes(self):\n    \"\"\"Return list of modes found in SavedModel.\"\"\"\n    available_modes = []\n    tf.compat.v1.logging.info(\n        'Checking available modes for SavedModelEstimator.')\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      try:\n        self._get_meta_graph_def_for_mode(mode)\n      except RuntimeError:\n        tf.compat.v1.logging.warn('%s mode not found in SavedModel.' % mode)\n        continue\n\n      if self._get_signature_def_for_mode(mode) is not None:\n        available_modes.append(mode)\n\n    tf.compat.v1.logging.info('Available modes for Estimator: %s' %\n                              available_modes)\n    return available_modes\n\n  def _validate_mode(self, mode):\n    \"\"\"Make sure that mode can be run using the SavedModel.\"\"\"\n    if mode not in self._available_modes:\n      raise RuntimeError('%s mode is not available in the SavedModel. 
Use '\n                         'saved_model_cli to check that the Metagraph for this '\n                         'mode has been exported.' % mode)\n\n  def _get_meta_graph_def_for_mode(self, mode):\n    tags = export_lib.EXPORT_TAG_MAP[mode]\n    return self.saved_model_loader.get_meta_graph_def_from_tags(tags)\n\n  def _get_signature_def_for_mode(self, mode):\n    meta_graph_def = self._get_meta_graph_def_for_mode(mode)\n    if mode == ModeKeys.PREDICT:\n      sig_def_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n    else:\n      sig_def_key = mode\n    if sig_def_key not in meta_graph_def.signature_def:\n      tf.compat.v1.logging.warn(\n          'Metagraph for mode %s was found, but SignatureDef with'\n          ' key \\\"%s\\\" is missing.' % (mode, sig_def_key))\n      return None\n    return meta_graph_def.signature_def[sig_def_key]\n\n  def _get_saver_def_from_mode(self, mode):\n    meta_graph_def = self._get_meta_graph_def_for_mode(mode)\n    return meta_graph_def.saver_def\n\n  def _create_and_assert_global_step(self, graph):\n    # Do nothing here. The global step variable will be created/loaded from the\n    # SavedModel. If a global step variable were created here, the result\n    # will be two duplicate global step variables, causing issues during\n    # the warm-start phase.\n    # Due to the global variable being created in the model function, this may\n    # cause issues when running DistributionStrategy. Thus, DistributionStrategy\n    # is not yet supported with SavedModelEstimator.\n    return None\n\n  def _model_fn_from_saved_model(self, features, labels, mode):\n    \"\"\"Load a SavedModel graph and return an EstimatorSpec.\"\"\"\n    # TODO(kathywu): Model function loads placeholders from the graph. Calling\n    # export_all_saved_models creates another placeholder for the inputs, on top\n    # of the original placeholders. 
There should be a way to avoid this.\n    self._validate_mode(mode)\n\n    g = tf.compat.v1.get_default_graph()\n    if tf.compat.v1.train.get_global_step(g) is not None:\n      raise RuntimeError(\n          'Graph must not contain a global step tensor before the SavedModel is'\n          ' loaded. Please make sure that the input function does not create a '\n          'global step.')\n\n    # Extract SignatureDef for information about the input and output tensors.\n    signature_def = self._get_signature_def_for_mode(mode)\n\n    # Generate input map for replacing the inputs in the SavedModel graph with\n    # the provided features and labels.\n    input_map = _generate_input_map(signature_def, features, labels)\n\n    # Create a list of the names of output tensors. When the graph is loaded,\n    # names of the output tensors may be remapped. This ensures that the correct\n    # tensors are returned in the EstimatorSpec.\n    output_tensor_names = [\n        value.name for value in six.itervalues(signature_def.outputs)\n    ]\n\n    # Load the graph. `output_tensors` contains output `Tensors` in the same\n    # same order as the `output_tensor_names` list.\n    tags = export_lib.EXPORT_TAG_MAP[mode]\n    _, output_tensors = self.saved_model_loader.load_graph(\n        g, tags, input_map=input_map, return_elements=output_tensor_names)\n\n    # Create saver object, and restore from the SavedModel `variables` directory\n    # if no checkpoints have been saved in the `model_dir`.\n    saver_obj = tf.compat.v1.train.Saver(\n        saver_def=self._get_saver_def_from_mode(mode))\n    init_fn = None\n    if not super(SavedModelEstimator, self).latest_checkpoint():\n      init_fn = self._restore_from_saver\n\n    # Create a scaffold from the MetaGraphDef that contains ops to initialize\n    # the graph. 
This should mirror the steps from _add_meta_graph_for_mode(),\n    # which creates a MetaGraphDef from the EstimatorSpec's scaffold.\n    # Get asset tensors, if any.\n    meta_graph_def = self._get_meta_graph_def_for_mode(mode)\n    asset_tensors_dictionary = loader_impl.get_asset_tensors(\n        self.saved_model_loader.export_dir, meta_graph_def, import_scope=None)\n    # TODO(kathywu): switch to loader_impl._get_main_op\n    scaffold = tf.compat.v1.train.Scaffold(\n        local_init_op=loader_impl._get_main_op_tensor(  # pylint: disable=protected-access\n            meta_graph_def),\n        local_init_feed_dict=asset_tensors_dictionary,\n        saver=saver_obj,\n        init_fn=init_fn)\n\n    # Ensure that a global step tensor has been created.\n    global_step_tensor = tf.compat.v1.train.get_global_step(g)\n    tf.compat.v1.train.assert_global_step(global_step_tensor)\n\n    # Extract values to return in the EstimatorSpec.\n    output_map = dict(zip(output_tensor_names, output_tensors))\n    outputs = {\n        key: output_map[value.name]\n        for key, value in six.iteritems(signature_def.outputs)\n    }\n\n    loss, predictions, metrics = _validate_and_extract_outputs(\n        mode, outputs, signature_def.method_name)\n\n    train_op = tf.compat.v1.get_collection(constants.TRAIN_OP_KEY)\n    if len(train_op) > 1:\n      raise RuntimeError('Multiple ops found in the train_op collection.')\n    train_op = None if not train_op else train_op[0]\n\n    _clear_saved_model_collections()\n    return model_fn_lib.EstimatorSpec(\n        scaffold=scaffold,\n        mode=mode,\n        loss=loss,\n        train_op=train_op,\n        predictions=predictions,\n        eval_metric_ops=metrics)\n\n  def _restore_from_saver(self, scaffold, session):\n    return scaffold.saver.restore(session,\n                                  _get_saved_model_ckpt(self.saved_model_dir))\n\n  def latest_checkpoint(self):\n    \"\"\"Returns the filename of the latest saved 
checkpoint.\n\n    Returns:\n      Filename of latest checkpoint in `model_dir`. If no checkpoints are found\n      in `model_dir`, then the path to the SavedModel checkpoint is returned.\n    \"\"\"\n    return (super(SavedModelEstimator, self).latest_checkpoint() or\n            _get_saved_model_ckpt(self.saved_model_dir))\n\n\ndef _get_saved_model_ckpt(saved_model_dir):\n  \"\"\"Return path to variables checkpoint in a `SavedModel` directory.\"\"\"\n  if not tf.compat.v1.gfile.Exists(\n      os.path.join(\n          path_helpers.get_variables_dir(saved_model_dir),\n          tf.compat.as_text('variables.index'))):\n    raise ValueError('Directory provided has an invalid SavedModel format: %s' %\n                     saved_model_dir)\n  return path_helpers.get_variables_path(saved_model_dir)\n\n\ndef _clear_saved_model_collections():\n  \"\"\"Clear collections that are expected empty when exporting a SavedModel.\n\n  The SavedModel builder uses these collections to track ops necessary to\n  restore the graph state. These collections are expected to be empty before\n  MetaGraphs are added to the builder.\n  \"\"\"\n  del tf.compat.v1.get_collection_ref(tf.saved_model.ASSETS_KEY)[:]\n  del tf.compat.v1.get_collection_ref(\n      tf.compat.v1.saved_model.LEGACY_INIT_OP_KEY)[:]\n  del tf.compat.v1.get_collection_ref(tf.compat.v1.saved_model.MAIN_OP_KEY)[:]\n  del tf.compat.v1.get_collection_ref(constants.TRAIN_OP_KEY)[:]\n\n\ndef _generate_input_map(signature_def, features, labels):\n  \"\"\"Return dict mapping an input tensor name to a feature or label tensor.\n\n  Args:\n    signature_def: SignatureDef loaded from SavedModel\n    features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or\n      `SparseTensor`, specifying the features to be passed to the model.\n    labels: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or\n      `SparseTensor`, specifying the labels to be passed to the model. 
May be\n      `None`.\n\n  Returns:\n    dict mapping string names of inputs to features or labels tensors\n\n  Raises:\n    ValueError: if SignatureDef inputs are not completely mapped by the input\n      features and labels.\n  \"\"\"\n  # Ensure that features and labels are dictionaries. If not, convert each to\n  # a dictionary with a single item. The default keys are different for features\n  # and labels.\n  features = export_lib.wrap_and_check_input_tensors(features, 'feature')\n  if labels is not None:\n    # Unlike features, labels may be None (in prediction mode)\n    labels = export_lib.wrap_and_check_input_tensors(labels, 'label')\n\n  inputs = signature_def.inputs\n  input_map = {}\n  for key, tensor_info in six.iteritems(inputs):\n    input_name = tensor_info.name\n    if ':' in input_name:\n      input_name = input_name[:input_name.find(':')]\n\n    # When tensors are used as control inputs for operations, their names are\n    # prepended with a '^' character in the GraphDef. To handle possible control\n    # flow edge cases, control input names must be included in the input map.\n    control_dependency_name = '^' + input_name\n\n    if key in features:\n      _check_same_dtype_and_shape(features[key], tensor_info, key)\n      input_map[input_name] = input_map[control_dependency_name] = features[key]\n    elif labels is not None and key in labels:\n      _check_same_dtype_and_shape(labels[key], tensor_info, key)\n      input_map[input_name] = input_map[control_dependency_name] = labels[key]\n    else:\n      raise ValueError(\n          'Key \\\"%s\\\" not found in features or labels passed in to the model '\n          'function. 
All required keys: %s' % (key, inputs.keys()))\n\n  return input_map\n\n\ndef _check_same_dtype_and_shape(tensor, tensor_info, name):\n  \"\"\"Validate that tensor has the same properties as the TensorInfo proto.\n\n  Args:\n    tensor: a `Tensor` object.\n    tensor_info: a `TensorInfo` proto.\n    name: Name of the input (to identify Tensor if an error is raised).\n\n  Raises:\n    ValueError: If the tensor shape or dtype don't match the TensorInfo\n  \"\"\"\n  dtype_error = (tensor.dtype != tf.dtypes.DType(tensor_info.dtype))\n  shape_error = not tensor.shape.is_compatible_with(tensor_info.tensor_shape)\n\n  if dtype_error or shape_error:\n    msg = 'Tensor shape and/or dtype validation failed for input %s:' % name\n    if dtype_error:\n      msg += ('\\n\\tExpected dtype: %s, Got: %s' %\n              (tf.dtypes.DType(tensor_info.dtype), tensor.dtype))\n    if shape_error:\n      msg += ('\\n\\tExpected shape: %s, Got: %s' %\n              (tf.TensorShape(tensor_info.tensor_shape), tensor.shape))\n\n    raise ValueError(msg)\n\n\ndef _extract_eval_metrics(output_dict):\n  \"\"\"Return a eval metric dict extracted from the output_dict.\n\n  Eval metrics consist of a value tensor and an update op. 
Both must be in the\n  passed-in tensor dictionary for an eval metric to be added to the returned\n  dictionary.\n\n  Args:\n    output_dict: a dict that maps strings to tensors.\n\n  Returns:\n    dict mapping strings to (value, update_op) tuples.\n  \"\"\"\n  # pylint: disable=protected-access\n  metric_ops = {}\n  separator_char = export_lib._SupervisedOutput._SEPARATOR_CHAR\n\n  for key, tensor in six.iteritems(output_dict):\n    split_key = key.split(separator_char)\n\n    # The metric name may contain the separator character, so recreate its name.\n    metric_name = separator_char.join(split_key[:-1])\n\n    if split_key[0] == export_lib._SupervisedOutput.METRICS_NAME:\n      # If the key ends with the value suffix, and there is a corresponding\n      # key ending with the update_op suffix, then add tensors to metrics dict.\n      if split_key[-1] == export_lib._SupervisedOutput.METRIC_VALUE_SUFFIX:\n        update_op = ''.join([\n            metric_name, separator_char,\n            export_lib._SupervisedOutput.METRIC_UPDATE_SUFFIX\n        ])\n        if update_op in output_dict:\n          update_op_tensor = output_dict[update_op]\n          metric_ops[metric_name] = (tensor, update_op_tensor)\n\n  # pylint: enable=protected-access\n  return metric_ops\n\n\ndef _validate_and_extract_outputs(mode, output_dict, method_name):\n  \"\"\"Extract values from SignatureDef output dictionary.\n\n  Args:\n    mode: One of the modes enumerated in `tf.estimator.ModeKeys`.\n    output_dict: dict of string SignatureDef keys to `Tensor`.\n    method_name: Method name of the SignatureDef as a string.\n\n  Returns:\n    Tuple of (\n      loss: `Tensor` object,\n      predictions: dictionary mapping string keys to `Tensor` objects,\n      metrics: dictionary mapping string keys to a tuple of two `Tensor` objects\n    )\n\n  Raises:\n    RuntimeError: raised if SignatureDef has an invalid method name for the mode\n  \"\"\"\n  # pylint: disable=protected-access\n  loss, 
predictions, metrics = None, None, None\n\n  if mode == ModeKeys.PREDICT:\n    predictions = output_dict\n  else:\n    # Validate that the SignatureDef's method name matches the expected name for\n    # the given mode.\n    expected_method_name = signature_constants.SUPERVISED_TRAIN_METHOD_NAME\n    if mode == ModeKeys.EVAL:\n      expected_method_name = signature_constants.SUPERVISED_EVAL_METHOD_NAME\n    if method_name != expected_method_name:\n      raise RuntimeError(\n          'Invalid SignatureDef method name for mode %s.\\n\\tExpected: %s\\n\\t'\n          'Got: %s\\nPlease ensure that the SavedModel was exported with '\n          '`tf.estimator.experimental_export_all_saved_models()`.' %\n          (mode, expected_method_name, method_name))\n\n    # Extract loss, metrics and predictions from the output dict.\n    loss = output_dict[export_lib._SupervisedOutput.LOSS_NAME]\n    metrics = _extract_eval_metrics(output_dict)\n    predictions = {\n        key: value\n        for key, value in six.iteritems(output_dict)\n        if key.split(export_lib._SupervisedOutput._SEPARATOR_CHAR)[0] == (\n            export_lib._SupervisedOutput.PREDICTIONS_NAME)\n    }\n\n  # pylint: enable=protected-access\n  return loss, predictions, metrics\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/saved_model_estimator_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for SavedModelEstimator.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport shutil\nimport tempfile\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.framework import test_util\nfrom tensorflow.python.framework.ops import add_to_collection\nfrom tensorflow.python.framework.ops import GraphKeys\nfrom tensorflow.python.ops import lookup_ops\nfrom tensorflow.python.training import saver_test_utils\nfrom tensorflow_estimator.python.estimator import estimator\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.canned import saved_model_estimator\nfrom tensorflow_estimator.python.estimator.export import export_lib\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n\ndef dummy_input_fn():\n  return tf.compat.v1.data.Dataset.from_tensors(({\n      'x': tf.constant([[1], [-2]], name='feature_x')\n  }, tf.constant([[4], [-3]], name='truth'))).repeat()\n\n\ndef _serving_feature_dict():\n  return {'x': tf.constant([[5], [6]], name='feature_x')}\n\n\ndef dummy_input_fn_features_only():\n  return 
tf.compat.v1.data.Dataset.from_tensors(\n      _serving_feature_dict()).repeat()\n\n\ndef dummy_supervised_receiver_fn():\n  return export_lib.build_supervised_input_receiver_fn_from_input_fn(\n      dummy_input_fn)\n\n\ndef dummy_serving_receiver_fn():\n  return export_lib.build_raw_serving_input_receiver_fn(_serving_feature_dict())\n\n\ndef model_fn_diff_modes(features, labels, mode):\n  _, _ = features, labels\n  v = tf.Variable(21, name='some_var')\n  train_op = None\n  loss = tf.constant(104)\n  if mode == ModeKeys.TRAIN:\n    loss = tf.constant(105)\n    predictions = tf.constant([501])\n    train_op = tf.group(\n        tf.compat.v1.assign_add(tf.compat.v1.train.get_global_step(), 1),\n        tf.compat.v1.assign_add(v, 3))\n  elif mode == ModeKeys.EVAL:\n    loss = tf.constant(106)\n    predictions = tf.constant([502])\n  else:\n    loss = tf.constant(107)\n    predictions = tf.constant([503])\n  return model_fn_lib.EstimatorSpec(\n      mode,\n      loss=loss,\n      train_op=train_op,\n      eval_metric_ops={\n          'abs_err':\n              tf.compat.v1.metrics.mean_absolute_error(\n                  tf.constant(0), predictions)\n      },\n      predictions=predictions)\n\n\ndef model_fn_with_trackable(features, labels, mode):\n  spec = model_fn_diff_modes(features, labels, mode)\n  predictions = spec.predictions\n\n  trackable_variable_ = saver_test_utils.CheckpointedOp(name='v2')\n\n  if mode == ModeKeys.TRAIN:\n    init_op = trackable_variable_.insert('key1', 2.2)\n    add_to_collection(GraphKeys.TABLE_INITIALIZERS, init_op)\n  else:\n    looked_up = trackable_variable_.lookup('key1', 0.0)\n    predictions = tf.constant([503.0]) + looked_up\n\n  return model_fn_lib.EstimatorSpec(\n      mode,\n      loss=spec.loss,\n      train_op=spec.train_op,\n      eval_metric_ops=spec.eval_metric_ops,\n      predictions=predictions)\n\n\n@test_util.run_v1_only('b/122480158')\nclass SavedModelEstimatorTest(tf.test.TestCase):\n\n  def setUp(self):\n    
super(SavedModelEstimatorTest, self).setUp()\n    self.tmpdirs = []\n\n  def tearDown(self):\n    for tmpdir in self.tmpdirs:\n      # gfile.DeleteRecursively fails in the windows cmake test, so use shutil.\n      shutil.rmtree(tmpdir, ignore_errors=True)\n    self.tmpdirs = []\n    super(SavedModelEstimatorTest, self).tearDown()\n\n  def _get_tmp_dir(self):\n    tmpdir = tempfile.mkdtemp()\n    self.tmpdirs.append(tmpdir)\n    return tmpdir\n\n  def _export_estimator(self,\n                        train=True,\n                        evaluate=True,\n                        predict=True,\n                        model_fn=model_fn_diff_modes):\n    est = estimator.Estimator(model_fn, self._get_tmp_dir())\n    est.train(input_fn=dummy_input_fn, steps=10)\n\n    input_receiver_fn_map = {}\n    if train:\n      input_receiver_fn_map[ModeKeys.TRAIN] = (dummy_supervised_receiver_fn())\n    if evaluate:\n      input_receiver_fn_map[ModeKeys.EVAL] = (dummy_supervised_receiver_fn())\n    if predict:\n      input_receiver_fn_map[ModeKeys.PREDICT] = (dummy_serving_receiver_fn())\n\n    export_base_path = self._get_tmp_dir()\n    export_dir = est.experimental_export_all_saved_models(\n        export_base_path, input_receiver_fn_map)\n    return export_dir\n\n  def test_load_all_modes(self):\n    sme = saved_model_estimator.SavedModelEstimator(self._export_estimator(),\n                                                    self._get_tmp_dir())\n    sme.train(input_fn=dummy_input_fn, steps=1)\n    sme.train(input_fn=dummy_input_fn, steps=2)\n    self.assertEqual(13, sme.get_variable_value('global_step'))\n    self.assertEqual(60, sme.get_variable_value('some_var'))\n\n    eval_results = sme.evaluate(dummy_input_fn, steps=5)\n\n    self.assertEqual(13, eval_results['global_step'])\n    self.assertEqual(106, eval_results['loss'])\n    self.assertEqual(502, eval_results['metrics/abs_err'])\n\n    predictions = next(sme.predict(dummy_input_fn_features_only))\n    
self.assertDictEqual({'output': 503}, predictions)\n\n  def test_load_all_modes_no_train(self):\n    \"\"\"Ensure that all functions can be used without requiring a ckpt.\"\"\"\n    sme = saved_model_estimator.SavedModelEstimator(self._export_estimator(),\n                                                    self._get_tmp_dir())\n    eval_results = sme.evaluate(dummy_input_fn, steps=5)\n    self.assertEqual(10, eval_results['global_step'])\n    self.assertEqual(106, eval_results['loss'])\n    self.assertEqual(502, eval_results['metrics/abs_err'])\n\n    predictions = next(sme.predict(dummy_input_fn_features_only))\n    self.assertDictEqual({'output': 503}, predictions)\n\n  def test_partial_exported_estimator(self):\n    sme1 = saved_model_estimator.SavedModelEstimator(\n        self._export_estimator(train=False, predict=False), self._get_tmp_dir())\n    sme1.evaluate(dummy_input_fn, steps=5)\n    with self.assertRaisesRegexp(RuntimeError, 'train mode is not available'):\n      sme1.train(input_fn=dummy_input_fn, steps=1)\n    with self.assertRaisesRegexp(RuntimeError, 'infer mode is not available'):\n      next(sme1.predict(dummy_input_fn_features_only))\n\n    sme2 = saved_model_estimator.SavedModelEstimator(\n        self._export_estimator(evaluate=False), self._get_tmp_dir())\n    sme2.train(input_fn=dummy_input_fn, steps=1)\n    next(sme2.predict(dummy_input_fn_features_only))\n    with self.assertRaisesRegexp(RuntimeError, 'eval mode is not available'):\n      sme2.evaluate(dummy_input_fn, steps=5)\n\n  def test_with_incorrect_input(self):\n    sme = saved_model_estimator.SavedModelEstimator(self._export_estimator(),\n                                                    self._get_tmp_dir())\n\n    def bad_shape_input_fn():\n      return tf.compat.v1.data.Dataset.from_tensors(({\n          'x': tf.constant([1, 2], dtype=tf.dtypes.int64)\n      }, tf.constant([1, 2], dtype=tf.dtypes.float32)))\n\n    with self.assertRaisesRegexp(ValueError, 'Expected shape'):\n  
    sme.train(bad_shape_input_fn, steps=1)\n\n    def bad_dtype_input_fn():\n      return tf.compat.v1.data.Dataset.from_tensors(({\n          'x': tf.constant([[1], [1]], dtype=tf.dtypes.int32)\n      }, tf.constant([[1], [1]], dtype=tf.dtypes.int64)))\n\n    with self.assertRaisesRegexp(ValueError, 'Expected dtype'):\n      sme.train(bad_dtype_input_fn, steps=1)\n\n  def test_input_fn_with_global_step(self):\n    sme = saved_model_estimator.SavedModelEstimator(self._export_estimator(),\n                                                    self._get_tmp_dir())\n\n    def bad_input_fn():\n      tf.compat.v1.train.get_or_create_global_step()\n      return tf.compat.v1.data.Dataset.from_tensors(({\n          'x': tf.constant([[1], [1]], dtype=tf.dtypes.int64)\n      }, tf.constant([[1], [1]], dtype=tf.dtypes.float32)))\n\n    with self.assertRaisesRegexp(RuntimeError,\n                                 'Graph must not contain a global step tensor'):\n      sme.train(bad_input_fn, steps=1)\n\n  def test_re_export_saved_model_serving_only(self):\n    sme = saved_model_estimator.SavedModelEstimator(self._export_estimator(),\n                                                    self._get_tmp_dir())\n    sme.train(dummy_input_fn, steps=3)\n    self.assertEqual(13, sme.get_variable_value('global_step'))\n    self.assertEqual(60, sme.get_variable_value('some_var'))\n\n    predictions = next(sme.predict(dummy_input_fn_features_only))\n    self.assertDictEqual({'output': 503}, predictions)\n\n    # Export SavedModel, and test that the variable and prediction values are\n    # the same.\n    sme_export_dir = sme.export_saved_model(self._get_tmp_dir(),\n                                            dummy_serving_receiver_fn())\n\n    sme2 = saved_model_estimator.SavedModelEstimator(sme_export_dir,\n                                                     self._get_tmp_dir())\n    self.assertEqual(60, sme.get_variable_value('some_var'))\n    self.assertEqual(13, 
sme.get_variable_value('global_step'))\n\n    predictions = next(sme2.predict(dummy_input_fn_features_only))\n    self.assertDictEqual({'output': 503}, predictions)\n\n  def test_re_export_saved_model(self):\n    sme = saved_model_estimator.SavedModelEstimator(self._export_estimator(),\n                                                    self._get_tmp_dir())\n    self.assertDictEqual(\n        {\n            'loss': 106,\n            'metrics/abs_err': 502,\n            'global_step': 10\n        }, sme.evaluate(dummy_input_fn, steps=1))\n\n    sme.train(dummy_input_fn, steps=3)\n    self.assertDictEqual(\n        {\n            'loss': 106,\n            'metrics/abs_err': 502,\n            'global_step': 13\n        }, sme.evaluate(dummy_input_fn, steps=1))\n    self.assertEqual(60, sme.get_variable_value('some_var'))\n\n    predictions = next(sme.predict(dummy_input_fn_features_only))\n    self.assertDictEqual({'output': 503}, predictions)\n\n    # Export SavedModel for all modes\n    input_receiver_fn_map = {\n        ModeKeys.TRAIN: dummy_supervised_receiver_fn(),\n        ModeKeys.EVAL: dummy_supervised_receiver_fn(),\n        ModeKeys.PREDICT: dummy_serving_receiver_fn()\n    }\n    sme_export_dir = sme.experimental_export_all_saved_models(\n        self._get_tmp_dir(), input_receiver_fn_map)\n\n    sme2 = saved_model_estimator.SavedModelEstimator(sme_export_dir,\n                                                     self._get_tmp_dir())\n    self.assertDictEqual(\n        {\n            'loss': 106,\n            'metrics/abs_err': 502,\n            'global_step': 13\n        }, sme.evaluate(dummy_input_fn, steps=1))\n    self.assertEqual(60, sme.get_variable_value('some_var'))\n\n    sme.train(dummy_input_fn, steps=7)\n    self.assertEqual(20, sme.get_variable_value('global_step'))\n\n    predictions = next(sme2.predict(dummy_input_fn_features_only))\n    self.assertDictEqual({'output': 503}, predictions)\n\n  def 
test_re_export_saved_model_with_trackable(self):\n    sme = saved_model_estimator.SavedModelEstimator(\n        self._export_estimator(model_fn=model_fn_with_trackable),\n        self._get_tmp_dir())\n\n    self.assertDictEqual(\n        {\n            'loss': 106,\n            'metrics/abs_err': 502,\n            'global_step': 10\n        }, sme.evaluate(dummy_input_fn, steps=1))\n\n    sme.train(dummy_input_fn, steps=3)\n    self.assertDictEqual(\n        {\n            'loss': 106,\n            'metrics/abs_err': 502,\n            'global_step': 13\n        }, sme.evaluate(dummy_input_fn, steps=1))\n    self.assertEqual(60, sme.get_variable_value('some_var'))\n\n    predictions = next(sme.predict(dummy_input_fn_features_only))\n    self.assertIn('output', predictions)\n    self.assertAlmostEqual(505.2, predictions['output'], places=4)\n\n    # Export SavedModel for all modes\n    input_receiver_fn_map = {\n        ModeKeys.TRAIN: dummy_supervised_receiver_fn(),\n        ModeKeys.EVAL: dummy_supervised_receiver_fn(),\n        ModeKeys.PREDICT: dummy_serving_receiver_fn()\n    }\n    sme_export_dir = sme.experimental_export_all_saved_models(\n        self._get_tmp_dir(), input_receiver_fn_map)\n\n    sme2 = saved_model_estimator.SavedModelEstimator(sme_export_dir,\n                                                     self._get_tmp_dir())\n\n    sme2.train(dummy_input_fn, steps=7)\n    self.assertEqual(20, sme2.get_variable_value('global_step'))\n    self.assertEqual(\n        81,  # 81 = 60 (last value) + 3 (step) * 7 (steps)\n        sme2.get_variable_value('some_var'))\n\n    predictions = next(sme2.predict(dummy_input_fn_features_only))\n    self.assertIn('output', predictions)\n    self.assertAlmostEqual(505.2, predictions['output'], places=4)\n\n  def test_load_saved_model_from_serving_only(self):\n\n    def model_fn(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          
loss=tf.constant([103]),\n          train_op=tf.compat.v1.assign_add(tf.compat.v1.train.get_global_step(),\n                                           1),\n          predictions=tf.constant([502]),\n          export_outputs={\n              'test': export_lib.ClassificationOutput(tf.constant([[32.]]))\n          })\n\n    est = estimator.Estimator(model_fn, self._get_tmp_dir())\n    est.train(input_fn=dummy_input_fn, steps=10)\n\n    def serving_input_receiver_fn():\n      return export_lib.ServingInputReceiver(\n          {'test-features': tf.constant([[1], [1]])},\n          tf.compat.v1.placeholder(dtype=tf.dtypes.string))\n\n    export_dir = est.export_saved_model(self._get_tmp_dir(),\n                                        serving_input_receiver_fn)\n\n    sme = saved_model_estimator.SavedModelEstimator(export_dir,\n                                                    self._get_tmp_dir())\n\n    def input_fn():\n      return {'inputs': tf.constant('someinputstr')}\n\n    prediction = next(sme.predict(input_fn))\n    self.assertDictEqual({'scores': 32}, prediction)\n\n  def test_with_local_init_op(self):\n\n    def model_fn(features, labels, mode):\n      _, _ = features, labels\n      v = tf.Variable(21, name='some_var')\n      scaffold = tf.compat.v1.train.Scaffold(\n          local_init_op=tf.compat.v1.assign_add(v, -3).op)\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          scaffold=scaffold,\n          train_op=tf.compat.v1.assign_add(tf.compat.v1.train.get_global_step(),\n                                           1),\n          loss=tf.identity(v))\n\n    export_dir = self._export_estimator(predict=False, model_fn=model_fn)\n    sme = saved_model_estimator.SavedModelEstimator(export_dir,\n                                                    self._get_tmp_dir())\n\n    eval_results1 = sme.evaluate(dummy_input_fn, steps=2)\n    self.assertEqual(15, eval_results1['loss'])\n\n    sme.train(dummy_input_fn, steps=1)\n    self.assertEqual(15, 
sme.get_variable_value('some_var'))\n\n    eval_results2 = sme.evaluate(dummy_input_fn, steps=5)\n    self.assertEqual(12, eval_results2['loss'])\n\n  def test_with_assets(self):\n    filename = 'test_asset'\n    tmpdir = tempfile.mkdtemp()\n    absolute_filepath = os.path.join(tmpdir, filename)\n    num_buckets = 1000\n    with open(absolute_filepath, 'w') as f:\n      f.write(six.ensure_str(b'test'))\n\n    def model_fn(features, labels, mode):\n      _, _ = features, labels\n      v = tf.Variable(0, name='some_var', dtype=tf.dtypes.int64)\n      # We verify the value of filepath_tensor is replaced with a path to the\n      # saved model's assets directory by assigning a hash of filepath_tensor\n      # to some_var.\n      filepath_tensor = ops.convert_to_tensor(absolute_filepath)\n      tf.compat.v1.add_to_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS,\n                                     filepath_tensor)\n      scaffold = tf.compat.v1.train.Scaffold(\n          local_init_op=tf.compat.v1.assign(\n              v, tf.strings.to_hash_bucket_fast(filepath_tensor,\n                                                num_buckets)).op)\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          scaffold=scaffold,\n          train_op=tf.compat.v1.assign_add(tf.compat.v1.train.get_global_step(),\n                                           1),\n          loss=tf.identity(0))\n\n    export_dir = self._export_estimator(predict=False, model_fn=model_fn)\n    sme = saved_model_estimator.SavedModelEstimator(export_dir,\n                                                    self._get_tmp_dir())\n\n    with self.session() as sess:\n      expected_bucket = sess.run(\n          tf.strings.to_hash_bucket_fast(\n              os.path.join(\n                  six.ensure_str(export_dir),\n                  six.ensure_str(tf.saved_model.ASSETS_DIRECTORY),\n                  six.ensure_str(filename)), num_buckets))\n\n    sme.train(dummy_input_fn, steps=1)\n    
self.assertEqual(expected_bucket, sme.get_variable_value('some_var'))\n\n  def test_with_working_input_fn(self):\n\n    def model_fn(features, labels, mode):\n      loss = None\n      if labels is not None:\n        loss = labels[0][0] + labels[1][0]\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=loss,\n          train_op=tf.compat.v1.assign_add(tf.compat.v1.train.get_global_step(),\n                                           1),\n          predictions={\n              'features_0': tf.identity([features['x'][0][0]]),\n              'features_1': tf.identity([features['x'][1][0]])\n          })\n\n    sme = saved_model_estimator.SavedModelEstimator(\n        self._export_estimator(model_fn=model_fn), self._get_tmp_dir())\n    eval_results = sme.evaluate(dummy_input_fn, steps=1)\n    self.assertEqual(1, eval_results['loss'])\n\n    predictions = next(sme.predict(dummy_input_fn_features_only))\n    self.assertDictEqual({'features_0': 5, 'features_1': 6}, predictions)\n\n  def test_control_dependency(self):\n    # Control dependencies are saved with \"^\" appended to the start of the input\n    # name. 
The input map must include control dependencies as well.\n    def model_fn(features, labels, mode):\n      _ = labels\n      with tf.control_dependencies([features['x']]):\n        loss = features['x'][1][0]\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=loss,\n          train_op=tf.compat.v1.assign_add(tf.compat.v1.train.get_global_step(),\n                                           1))\n\n    sme = saved_model_estimator.SavedModelEstimator(\n        self._export_estimator(train=False, predict=False, model_fn=model_fn),\n        self._get_tmp_dir())\n    sme.evaluate(dummy_input_fn, steps=1)  # Should run without error\n\n  def test_saveable_resources(self):\n\n    def model_fn(features, labels, mode):\n      tb = lookup_ops.MutableHashTable(\n          key_dtype=tf.dtypes.int32,\n          value_dtype=tf.dtypes.int32,\n          default_value=-1)\n      predictions = tb.lookup(features['x'])\n      train_op = None\n      if mode == ModeKeys.TRAIN:\n        train_op = tf.group(\n            tb.insert(features['x'], labels),\n            tf.compat.v1.assign_add(tf.compat.v1.train.get_global_step(), 1))\n      return model_fn_lib.EstimatorSpec(\n          mode, loss=tf.constant(0), predictions=predictions, train_op=train_op)\n\n    # Trains the model so that the table maps 1 -> 4, and -2 -> -3\n    # (see dummy_input_fn)\n    sme = saved_model_estimator.SavedModelEstimator(\n        self._export_estimator(model_fn=model_fn), self._get_tmp_dir())\n\n    def gen_input_fn(features, labels=None):\n\n      def fn():\n        if labels:\n          t = ({\n              'x': tf.constant(features, name='feature_x')\n          }, tf.constant(labels, name='truth'))\n        else:\n          t = {'x': tf.constant(features, name='feature_x')}\n        return tf.compat.v1.data.Dataset.from_tensors(t).repeat()\n\n      return fn\n\n    self.assertAllEqual([-1], next(sme.predict(gen_input_fn([[5]])))['output'])\n    self.assertAllEqual([4], 
next(sme.predict(gen_input_fn([[1]])))['output'])\n    sme.train(gen_input_fn([[5]], [[6]]), steps=1)\n    self.assertAllEqual([6], next(sme.predict(gen_input_fn([[5]])))['output'])\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/testdata/wire_vocabulary.txt",
    "content": "omar\nstringer\nmarlo\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/timeseries/BUILD",
    "content": "# Placeholder: load py_library\nload(\"//tensorflow_estimator:estimator.bzl\", \"py_test\")\n\npackage(default_visibility = [\"//tensorflow_estimator:__subpackages__\"])\n\nlicenses([\"notice\"])\n\npy_library(\n    name = \"feature_keys\",\n    srcs = [\n        \"feature_keys.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\"],\n)\n\npy_library(\n    name = \"saved_model_utils\",\n    srcs = [\n        \"saved_model_utils.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \":feature_keys\",\n        \":head\",\n        \":model_utils\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_library(\n    name = \"model\",\n    srcs = [\n        \"model.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \":feature_keys\",\n        \":math_utils\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_library(\n    name = \"estimators\",\n    srcs = [\n        \"estimators.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \":ar_model\",\n        \":feature_keys\",\n        \":head\",\n        \":math_utils\",\n        \":saved_model_utils\",\n        \":state_management\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_test(\n    name = \"estimators_test\",\n    srcs = [\n        \"estimators_test.py\",\n    ],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    tags = [\n        \"notap\",  # TODO(b/132129465): Re-enable.\n    ],\n    deps = [\n        \":ar_model\",\n        \":estimators\",\n        \":feature_keys\",\n        \":saved_model_utils\",\n        \"//tensorflow_estimator/python/estimator:estimator_py\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n   
     \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_library(\n    name = \"head\",\n    srcs = [\n        \"head.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \":feature_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_test(\n    name = \"head_test\",\n    srcs = [\n        \"head_test.py\",\n    ],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    deps = [\n        \":estimators\",\n        \":feature_keys\",\n        \":head\",\n        \":model\",\n        \":state_management\",\n        \"//tensorflow_estimator/python/estimator:estimator_py\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_library(\n    name = \"model_utils\",\n    srcs = [\n        \"model_utils.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \":feature_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_numpy_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_library(\n    name = \"state_management\",\n    srcs = [\n        \"state_management.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \":feature_keys\",\n        \":math_utils\",\n        \":model\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_library(\n    name = \"ar_model\",\n    srcs = [\n        \"ar_model.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \":feature_keys\",\n        \":model\",\n        \":model_utils\",\n        
\"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_test(\n    name = \"ar_model_test\",\n    srcs = [\n        \"ar_model_test.py\",\n    ],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    deps = [\n        \":ar_model\",\n        \":estimators\",\n        \":feature_keys\",\n        \"//tensorflow_estimator/python/estimator:estimator_py\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_test(\n    name = \"ar_model_training_test\",\n    srcs = [\n        \"ar_model_training_test.py\",\n    ],\n    python_version = \"PY3\",\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    deps = [\n        \":ar_model\",\n        \":estimators\",\n        \":feature_keys\",\n        \"//tensorflow_estimator/python/estimator:estimator_py\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_library(\n    name = \"math_utils\",\n    srcs = [\n        \"math_utils.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \":feature_keys\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_test(\n    name = \"math_utils_test\",\n    srcs = [\n        \"math_utils_test.py\",\n    ],\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":feature_keys\",\n        \":math_utils\",\n        \"//tensorflow_estimator/python/estimator:expect_proto_cpp_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/timeseries/ar_model.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Auto-Regressive models for time series data.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.ops import distributions\nfrom tensorflow.python.ops import gen_math_ops\nfrom tensorflow_estimator.python.estimator import estimator_lib\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned.timeseries import model\nfrom tensorflow_estimator.python.estimator.canned.timeseries import model_utils\nfrom tensorflow_estimator.python.estimator.canned.timeseries.feature_keys import PredictionFeatures\nfrom tensorflow_estimator.python.estimator.canned.timeseries.feature_keys import TrainEvalFeatures\n\n\nclass LSTMPredictionModel(tf_keras.models.Model):\n  \"\"\"A simple encoder/decoder model using an LSTM.\n\n  This model does not operate on its own, but rather is a plugin to\n  `ARModel`. 
See `ARModel`'s constructor documentation\n  (`prediction_model_factory`) for a usage example.\n  \"\"\"\n\n  def __init__(self,\n               num_features,\n               input_window_size,\n               output_window_size,\n               num_units=128):\n    \"\"\"Construct the LSTM prediction model.\n\n    Args:\n      num_features: number of input features per time step.\n      input_window_size: Number of past time steps of data to look at when doing\n        the regression.\n      output_window_size: Number of future time steps to predict. Note that\n        setting it to > 1 empirically seems to give a better fit.\n      num_units: The number of units in the encoder and decoder LSTM cells.\n    \"\"\"\n    super(LSTMPredictionModel, self).__init__()\n    self._encoder = tf_keras.layers.LSTM(\n        num_units, name=\"encoder\", dtype=self.dtype, return_state=True)\n    self._decoder = tf_keras.layers.LSTM(\n        num_units, name=\"decoder\", dtype=self.dtype, return_sequences=True)\n    self._mean_transform = tf_keras.layers.Dense(num_features, name=\"mean_transform\")\n    self._covariance_transform = tf_keras.layers.Dense(\n        num_features, name=\"covariance_transform\")\n\n  def call(self, input_window_features, output_window_features):\n    \"\"\"Compute predictions from input and output windows.\"\"\"\n    _, state_h, state_c = self._encoder(input_window_features)\n    encoder_states = [state_h, state_c]\n    decoder_output = self._decoder(\n        output_window_features, initial_state=encoder_states)\n    predicted_mean = self._mean_transform(decoder_output)\n    predicted_covariance = gen_math_ops.exp(\n        self._covariance_transform(decoder_output))\n    return {\"mean\": predicted_mean, \"covariance\": predicted_covariance}\n\n\nclass ARModel(model.TimeSeriesModel):\n  \"\"\"Auto-regressive model, both linear and non-linear.\n\n  Features to the model include time and values of input_window_size timesteps,\n  and times for 
output_window_size timesteps. These are passed through a\n  configurable prediction model, and then fed to a loss function (e.g. squared\n  loss).\n\n  Note that this class can also be used to regress against time only by setting\n  the input_window_size to zero.\n\n  Each periodicity in the `periodicities` arg is divided by the\n  `num_time_buckets` into time buckets that are represented as features added\n  to the model.\n\n  A good heuristic for picking an appropriate periodicity for a given data set\n  would be the length of cycles in the data. For example, energy usage in a\n  home is typically cyclic each day. If the time feature in a home energy\n  usage dataset is in the unit of hours, then 24 would be an appropriate\n  periodicity. Similarly, a good heuristic for `num_time_buckets` is how often\n  the data is expected to change within the cycle. For the aforementioned home\n  energy usage dataset and periodicity of 24, then 48 would be a reasonable\n  value if usage is expected to change every half hour.\n\n  Each feature's value for a given example with time t is the difference\n  between t and the start of the time bucket it falls under. 
If it doesn't fall\n  under a feature's associated time bucket, then that feature's value is zero.\n\n  For example: if `periodicities` = (9, 12) and `num_time_buckets` = 3, then 6\n  features would be added to the model, 3 for periodicity 9 and 3 for\n  periodicity 12.\n\n  For an example data point where t = 17:\n  - It's in the 3rd time bucket for periodicity 9 (2nd period is 9-18 and 3rd\n    time bucket is 15-18)\n  - It's in the 2nd time bucket for periodicity 12 (2nd period is 12-24 and\n    2nd time bucket is between 16-20).\n\n  Therefore the 6 added features for this row with t = 17 would be:\n\n  # Feature name (periodicity#_timebucket#), feature value\n  P9_T1, 0 # not in first time bucket\n  P9_T2, 0 # not in second time bucket\n  P9_T3, 2 # 17 - 15 since 15 is the start of the 3rd time bucket\n  P12_T1, 0 # not in first time bucket\n  P12_T2, 1 # 17 - 16 since 16 is the start of the 2nd time bucket\n  P12_T3, 0 # not in third time bucket\n  \"\"\"\n  SQUARED_LOSS = \"squared_loss\"\n  NORMAL_LIKELIHOOD_LOSS = \"normal_likelihood_loss\"\n\n  def __init__(self,\n               periodicities,\n               input_window_size,\n               output_window_size,\n               num_features,\n               prediction_model_factory=LSTMPredictionModel,\n               num_time_buckets=10,\n               loss=NORMAL_LIKELIHOOD_LOSS,\n               exogenous_feature_columns=None):\n    \"\"\"Constructs an auto-regressive model.\n\n    Args:\n      periodicities: periodicities of the input data, in the same units as the\n        time feature (for example 24 if feeding hourly data with a daily\n        periodicity, or 60 * 24 if feeding minute-level data with daily\n        periodicity). Note this can be a single value or a list of values for\n        multiple periodicities.\n      input_window_size: Number of past time steps of data to look at when doing\n        the regression.\n      output_window_size: Number of future time steps to predict. 
Note that\n        setting it to > 1 empirically seems to give a better fit.\n      num_features: number of input features per time step.\n      prediction_model_factory: A callable taking arguments `num_features`,\n        `input_window_size`, and `output_window_size` and returning a\n        `tf_keras.Model`. The `Model`'s `call()` takes two arguments: an input\n        window and an output window, and returns a dictionary of predictions.\n        See `LSTMPredictionModel` for an example. The default model computes\n        predictions as a linear function of flattened input and output windows.\n      num_time_buckets: Number of buckets into which to divide (time %\n        periodicity). This value multiplied by the number of periodicities is\n        the number of time features added to the model.\n      loss: Loss function to use for training. Currently supported values are\n        SQUARED_LOSS and NORMAL_LIKELIHOOD_LOSS. Note that for\n        NORMAL_LIKELIHOOD_LOSS, we train the covariance term as well. For\n        SQUARED_LOSS, the evaluation loss is reported based on un-scaled\n        observations and predictions, while the training loss is computed on\n        normalized data (if input statistics are available).\n      exogenous_feature_columns: A list of `tf.feature_column`s (for example\n        `tf.feature_column.embedding_column`) corresponding to features which\n        provide extra information to the model but are not part of the series to\n        be predicted.\n\n    Example usage:\n\n    >>> model = ar_model.ARModel(\n    ...    periodicities=2, num_features=3,\n    ...    prediction_model_factory=functools.partial(\n    ...       
LSTMPredictionModel, hidden_layer_sizes=[10, 10]))\n    \"\"\"\n    self._model_factory = prediction_model_factory\n    self.input_window_size = input_window_size\n    self.output_window_size = output_window_size\n    self.window_size = self.input_window_size + self.output_window_size\n    self.loss = loss\n    super(ARModel, self).__init__(\n        num_features=num_features,\n        exogenous_feature_columns=exogenous_feature_columns)\n    if exogenous_feature_columns is not None:\n      self.exogenous_size = self._get_exogenous_embedding_shape()[-1]\n    else:\n      self.exogenous_size = 0\n    assert num_time_buckets > 0\n    self._buckets = int(num_time_buckets)\n    if periodicities is None or not periodicities:\n      periodicities = []\n    elif (not isinstance(periodicities, list) and\n          not isinstance(periodicities, tuple)):\n      periodicities = [periodicities]\n    self._periodicities = [int(p) for p in periodicities]\n    for p in self._periodicities:\n      assert p > 0\n    assert len(self._periodicities) or self.input_window_size\n    assert output_window_size > 0\n\n  def initialize_graph(self, input_statistics=None):\n    super(ARModel, self).initialize_graph(input_statistics=input_statistics)\n    self._model_scope = tf.compat.v1.variable_scope(\n        # The trailing slash means we strip all enclosing variable_scopes, which\n        # unfortunately is necessary because the model gets called inside and\n        # outside a \"while\" scope (for prediction and training respectively),\n        # and the variables names need to match.\n        \"model/\",\n        use_resource=True)\n    self._model_instance = self._model_factory(\n        num_features=self.num_features,\n        input_window_size=self.input_window_size,\n        output_window_size=self.output_window_size)\n\n  def get_start_state(self):\n    # State which matches the format we'll return later. 
Typically this will not\n    # be used by the model directly, but the shapes and dtypes should match so\n    # that the serving input_receiver_fn gets placeholder shapes correct.\n    return (tf.zeros([self.input_window_size], dtype=tf.dtypes.int64),\n            tf.zeros([self.input_window_size, self.num_features],\n                     dtype=self.dtype),\n            tf.zeros([self.input_window_size, self.exogenous_size],\n                     dtype=self.dtype))\n\n  # TODO(allenl,agarwal): Support sampling for AR.\n  def random_model_parameters(self, seed=None):\n    pass\n\n  def generate(self,\n               number_of_series,\n               series_length,\n               model_parameters=None,\n               seed=None):\n    pass\n\n  def _predicted_covariance_op(self, activations, num_values):\n    activation, activation_size = activations[-1]\n    if self.loss == ARModel.NORMAL_LIKELIHOOD_LOSS:\n      log_sigma_square = model_utils.fully_connected(\n          activation,\n          activation_size,\n          self.output_window_size * num_values,\n          name=\"log_sigma_square\",\n          activation=None)\n      predicted_covariance = gen_math_ops.exp(log_sigma_square)\n      predicted_covariance = tf.reshape(\n          predicted_covariance, [-1, self.output_window_size, num_values])\n    else:\n      shape = tf.stack([\n          tf.compat.v1.shape(activation)[0],\n          tf.constant(self.output_window_size),\n          tf.constant(num_values)\n      ])\n      predicted_covariance = tf.ones(shape=shape, dtype=activation.dtype)\n    return predicted_covariance\n\n  def _predicted_mean_op(self, activations):\n    activation, activation_size = activations[-1]\n    predicted_mean = model_utils.fully_connected(\n        activation,\n        activation_size,\n        self.output_window_size * self.num_features,\n        name=\"predicted_mean\",\n        activation=None)\n    return tf.reshape(predicted_mean,\n                      [-1, 
self.output_window_size, self.num_features])\n\n  def prediction_ops(self, times, values, exogenous_regressors):\n    \"\"\"Compute model predictions given input data.\n\n    Args:\n      times: A [batch size, self.window_size] integer Tensor, the first\n        self.input_window_size times in each part of the batch indicating input\n        features, and the last self.output_window_size times indicating\n        prediction times.\n      values: A [batch size, self.input_window_size, self.num_features] Tensor\n        with input features.\n      exogenous_regressors: A [batch size, self.window_size,\n        self.exogenous_size] Tensor with exogenous features.\n\n    Returns:\n      Tuple (predicted_mean, predicted_covariance), where each element is a\n      Tensor with shape [batch size, self.output_window_size,\n      self.num_features].\n    \"\"\"\n    times.get_shape().assert_is_compatible_with([None, self.window_size])\n    batch_size = tf.compat.v1.shape(times)[0]\n    if self.input_window_size:\n      values.get_shape().assert_is_compatible_with(\n          [None, self.input_window_size, self.num_features])\n    if exogenous_regressors is not None:\n      exogenous_regressors.get_shape().assert_is_compatible_with(\n          [None, self.window_size, self.exogenous_size])\n    # Create input features.\n    input_window_features = []\n    input_feature_size = 0\n    output_window_features = []\n    output_feature_size = 0\n    if self._periodicities:\n      _, time_features = self._compute_time_features(times)\n      num_time_features = self._buckets * len(self._periodicities)\n      time_features = tf.reshape(\n          time_features, [batch_size, self.window_size, num_time_features])\n      input_time_features, output_time_features = tf.split(\n          time_features, (self.input_window_size, self.output_window_size),\n          axis=1)\n      input_feature_size += num_time_features\n      output_feature_size += num_time_features\n      
input_window_features.append(input_time_features)\n      output_window_features.append(output_time_features)\n    if self.input_window_size:\n      inp = tf.slice(values, [0, 0, 0], [-1, self.input_window_size, -1])\n      input_window_features.append(\n          tf.reshape(inp,\n                     [batch_size, self.input_window_size, self.num_features]))\n      input_feature_size += self.num_features\n    if self.exogenous_size:\n      input_exogenous_features, output_exogenous_features = tf.split(\n          exogenous_regressors,\n          (self.input_window_size, self.output_window_size),\n          axis=1)\n      input_feature_size += self.exogenous_size\n      output_feature_size += self.exogenous_size\n      input_window_features.append(input_exogenous_features)\n      output_window_features.append(output_exogenous_features)\n    assert input_window_features\n    input_window_features = tf.concat(input_window_features, axis=2)\n    if output_window_features:\n      output_window_features = tf.concat(output_window_features, axis=2)\n    else:\n      output_window_features = tf.zeros(\n          [batch_size, self.output_window_size, 0], dtype=self.dtype)\n    static_batch_size = times.get_shape().dims[0].value\n    input_window_features.set_shape(\n        [static_batch_size, self.input_window_size, input_feature_size])\n    output_window_features.set_shape(\n        [static_batch_size, self.output_window_size, output_feature_size])\n    return self._output_window_predictions(input_window_features,\n                                           output_window_features)\n\n  def _output_window_predictions(self, input_window_features,\n                                 output_window_features):\n    with self._model_scope:\n      predictions = self._model_instance(input_window_features,\n                                         output_window_features)\n      result_shape = [None, self.output_window_size, self.num_features]\n      for v in predictions.values():\n     
   v.set_shape(result_shape)\n      return predictions\n\n  def loss_op(self, targets, prediction_ops):\n    \"\"\"Create loss_op.\"\"\"\n    prediction = prediction_ops[\"mean\"]\n    if self.loss == ARModel.NORMAL_LIKELIHOOD_LOSS:\n      covariance = prediction_ops[\"covariance\"]\n      sigma = tf.math.sqrt(tf.math.maximum(covariance, 1e-5))\n      normal = distributions.normal.Normal(loc=targets, scale=sigma)\n      loss_op = -tf.math.reduce_sum(normal.log_prob(prediction))\n    else:\n      assert self.loss == ARModel.SQUARED_LOSS, self.loss\n      loss_op = tf.math.reduce_sum(tf.math.square(prediction - targets))\n    loss_op /= tf.cast(\n        tf.math.reduce_prod(tf.compat.v1.shape(targets)), loss_op.dtype)\n    return loss_op\n\n  def _process_exogenous_features(self, times, features):\n    embedded = super(ARModel, self)._process_exogenous_features(\n        times=times, features=features)\n    if embedded is None:\n      assert self.exogenous_size == 0\n      # No embeddings. 
Return a zero-size [batch, times, 0] array so we don't\n      # have to special case it downstream.\n      return tf.zeros(\n          tf.concat([tf.compat.v1.shape(times),\n                     tf.constant([0])], axis=0))\n    else:\n      return embedded\n\n  # TODO(allenl, agarwal): Consider better ways of warm-starting predictions.\n  def predict(self, features):\n    \"\"\"Computes predictions multiple steps into the future.\n\n    Args:\n      features: A dictionary with the following key/value pairs:\n        PredictionFeatures.TIMES: A [batch size, predict window size] integer\n          Tensor of times, after the window of data indicated by `STATE_TUPLE`,\n          to make predictions for.\n        PredictionFeatures.STATE_TUPLE: A tuple of (times, values), times with\n          shape [batch size, self.input_window_size], values with shape [batch\n          size, self.input_window_size, self.num_features] representing a\n          segment of the time series before `TIMES`. This data is used to start\n          of the autoregressive computation. This should have data for at least\n          self.input_window_size timesteps. And any exogenous features, with\n          shapes prefixed by shape of `TIMES`.\n\n    Returns:\n      A dictionary with keys, \"mean\", \"covariance\". 
The\n      values are Tensors of shape [batch_size, predict window size,\n      num_features] and correspond to the values passed in `TIMES`.\n    \"\"\"\n    if not self._graph_initialized:\n      self.initialize_graph()\n    predict_times = tf.cast(\n        ops.convert_to_tensor(features[PredictionFeatures.TIMES]),\n        tf.dtypes.int32)\n    exogenous_regressors = self._process_exogenous_features(\n        times=predict_times,\n        features={\n            key: value for key, value in features.items() if key not in [\n                TrainEvalFeatures.TIMES, TrainEvalFeatures.VALUES,\n                PredictionFeatures.STATE_TUPLE\n            ]\n        })\n    with tf.control_dependencies([\n        tf.compat.v1.debugging.assert_equal(\n            tf.compat.v1.shape(predict_times)[1],\n            tf.compat.v1.shape(exogenous_regressors)[1])\n    ]):\n      exogenous_regressors = tf.identity(exogenous_regressors)\n    batch_size = tf.compat.v1.shape(predict_times)[0]\n    num_predict_values = tf.compat.v1.shape(predict_times)[1]\n    prediction_iterations = (\n        (num_predict_values + self.output_window_size - 1) //\n        self.output_window_size)\n    # Pad predict_times and exogenous regressors so as to have exact multiple of\n    # self.output_window_size values per example.\n    padding_size = (\n        prediction_iterations * self.output_window_size - num_predict_values)\n    predict_times = tf.compat.v1.pad(predict_times, [[0, 0], [0, padding_size]])\n    exogenous_regressors = tf.compat.v1.pad(exogenous_regressors,\n                                            [[0, 0], [0, padding_size], [0, 0]])\n    state = features[PredictionFeatures.STATE_TUPLE]\n    (state_times, state_values, state_exogenous_regressors) = state\n    state_times = tf.cast(ops.convert_to_tensor(state_times), tf.dtypes.int32)\n    state_values = ops.convert_to_tensor(state_values, dtype=self.dtype)\n    state_exogenous_regressors = ops.convert_to_tensor(\n        
state_exogenous_regressors, dtype=self.dtype)\n\n    initial_input_times = predict_times[:, :self.output_window_size]\n    initial_input_exogenous_regressors = (\n        exogenous_regressors[:, :self.output_window_size, :])\n    if self.input_window_size > 0:\n      initial_input_times = tf.concat(\n          [state_times[:, -self.input_window_size:], initial_input_times], 1)\n      values_size = tf.compat.v1.shape(state_values)[1]\n      times_size = tf.compat.v1.shape(state_times)[1]\n      with tf.control_dependencies([\n          tf.compat.v1.debugging.assert_greater_equal(values_size,\n                                                      self.input_window_size),\n          tf.compat.v1.debugging.assert_equal(values_size, times_size)\n      ]):\n        initial_input_values = state_values[:, -self.input_window_size:, :]\n        initial_input_exogenous_regressors = tf.concat([\n            state_exogenous_regressors[:, -self.input_window_size:, :],\n            initial_input_exogenous_regressors[:, :self.output_window_size, :]\n        ],\n                                                       axis=1)\n    else:\n      initial_input_values = 0\n\n    # Iterate over the predict_times, predicting self.output_window_size values\n    # in each iteration.\n    def _while_condition(iteration_number, *unused_args):\n      return tf.math.less(iteration_number, prediction_iterations)\n\n    def _while_body(iteration_number, input_times, input_values,\n                    input_exogenous_regressors, mean_ta, covariance_ta):\n      \"\"\"Predict self.output_window_size values.\"\"\"\n      prediction_ops = self.prediction_ops(input_times, input_values,\n                                           input_exogenous_regressors)\n      predicted_mean = prediction_ops[\"mean\"]\n      predicted_covariance = prediction_ops[\"covariance\"]\n      offset = self.output_window_size * tf.math.minimum(\n          iteration_number + 1, prediction_iterations - 1)\n      if 
self.input_window_size > 0:\n        if self.output_window_size < self.input_window_size:\n          new_input_values = tf.concat(\n              [input_values[:, self.output_window_size:, :], predicted_mean], 1)\n          new_input_exogenous_regressors = tf.concat([\n              input_exogenous_regressors[:, -self.input_window_size:, :],\n              exogenous_regressors[\n                  :, offset:offset + self.output_window_size, :]\n          ], axis=1)\n          new_input_times = tf.concat([\n              input_times[:, -self.input_window_size:],\n              predict_times[:, offset:offset + self.output_window_size]\n          ], 1)\n        else:\n          new_input_values = predicted_mean[:, -self.input_window_size:, :]\n          new_input_exogenous_regressors = exogenous_regressors[\n              :,\n              offset - self.input_window_size:offset + self.output_window_size,\n              :]\n          new_input_times = predict_times[\n              :,\n              offset - self.input_window_size:offset + self.output_window_size]\n      else:\n        new_input_values = input_values\n        new_input_exogenous_regressors = exogenous_regressors[\n            :, offset:offset + self.output_window_size, :]\n        new_input_times = predict_times[:,\n                                        offset:offset + self.output_window_size]\n      new_input_times.set_shape(initial_input_times.get_shape())\n      new_input_exogenous_regressors.set_shape(\n          initial_input_exogenous_regressors.get_shape())\n      new_mean_ta = mean_ta.write(iteration_number, predicted_mean)\n      if isinstance(covariance_ta, tf.TensorArray):\n        new_covariance_ta = covariance_ta.write(iteration_number,\n                                                predicted_covariance)\n      else:\n        new_covariance_ta = covariance_ta\n      return (iteration_number + 1, new_input_times, new_input_values,\n              new_input_exogenous_regressors, 
new_mean_ta, new_covariance_ta)\n\n    # Note that control_flow_ops.while_loop doesn't seem happy with None. Hence\n    # using 0 for cases where we don't want to predict covariance.\n    covariance_ta_init = (\n        tf.TensorArray(dtype=self.dtype, size=prediction_iterations)\n        if self.loss != ARModel.SQUARED_LOSS else 0.)\n    mean_ta_init = tf.TensorArray(dtype=self.dtype, size=prediction_iterations)\n    _, _, _, _, mean_ta, covariance_ta = tf.compat.v1.while_loop(\n        _while_condition, _while_body, [\n            0, initial_input_times, initial_input_values,\n            initial_input_exogenous_regressors, mean_ta_init, covariance_ta_init\n        ])\n\n    def _parse_ta(values_ta):\n      \"\"\"Helper function to parse the returned TensorArrays.\"\"\"\n\n      if not isinstance(values_ta, tf.TensorArray):\n        return None\n      predictions_length = prediction_iterations * self.output_window_size\n      # Shape [prediction_iterations, batch_size, self.output_window_size,\n      #        self.num_features]\n      values_packed = values_ta.stack()\n      # Transpose to move batch dimension outside.\n      output_values = tf.reshape(\n          tf.compat.v1.transpose(values_packed, [1, 0, 2, 3]),\n          tf.stack([batch_size, predictions_length, -1]))\n      # Clip to desired size\n      return output_values[:, :num_predict_values, :]\n\n    predicted_mean = _parse_ta(mean_ta)\n    predicted_covariance = _parse_ta(covariance_ta)\n    if predicted_covariance is None:\n      predicted_covariance = tf.compat.v1.ones_like(predicted_mean)\n\n    # Transform and scale the mean and covariance appropriately.\n    predicted_mean = self._scale_back_data(predicted_mean)\n    predicted_covariance = self._scale_back_variance(predicted_covariance)\n\n    return {\"mean\": predicted_mean, \"covariance\": predicted_covariance}\n\n  def _process_window(self, features, mode, exogenous_regressors):\n    \"\"\"Compute model outputs on a single window of 
data.\"\"\"\n    times = tf.cast(features[TrainEvalFeatures.TIMES], tf.dtypes.int64)\n    values = tf.cast(features[TrainEvalFeatures.VALUES], dtype=self.dtype)\n    exogenous_regressors = tf.cast(exogenous_regressors, dtype=self.dtype)\n    original_values = values\n\n    # Extra shape checking for the window size (above that in\n    # `head.create_estimator_spec`).\n    expected_times_shape = [None, self.window_size]\n    if not times.get_shape().is_compatible_with(expected_times_shape):\n      raise ValueError(\n          (\"ARModel with input_window_size={input_window_size} \"\n           \"and output_window_size={output_window_size} expects \"\n           \"feature '{times_feature}' to have shape (batch_size, \"\n           \"{window_size}) (for any batch_size), but got shape {times_shape}. \"\n           \"If you are using RandomWindowInputFn, set \"\n           \"window_size={window_size} or adjust the input_window_size and \"\n           \"output_window_size arguments to ARModel.\").format(\n               input_window_size=self.input_window_size,\n               output_window_size=self.output_window_size,\n               times_feature=TrainEvalFeatures.TIMES,\n               window_size=self.window_size,\n               times_shape=times.get_shape()))\n    values = self._scale_data(values)\n    if self.input_window_size > 0:\n      input_values = values[:, :self.input_window_size, :]\n    else:\n      input_values = None\n    prediction_ops = self.prediction_ops(times, input_values,\n                                         exogenous_regressors)\n    prediction = prediction_ops[\"mean\"]\n    covariance = prediction_ops[\"covariance\"]\n    targets = tf.slice(values, [0, self.input_window_size, 0], [-1, -1, -1])\n    targets.get_shape().assert_is_compatible_with(prediction.get_shape())\n    if (mode == estimator_lib.ModeKeys.EVAL and\n        self.loss == ARModel.SQUARED_LOSS):\n      # Report an evaluation loss which matches the expected\n      #  
(observed - predicted) ** 2.\n      # Note that this affects only evaluation; the training loss is unaffected.\n      loss = self.loss_op(\n          self._scale_back_data(targets),\n          {\"mean\": self._scale_back_data(prediction_ops[\"mean\"])})\n    else:\n      loss = self.loss_op(targets, prediction_ops)\n\n    # Scale back the prediction.\n    prediction = self._scale_back_data(prediction)\n    covariance = self._scale_back_variance(covariance)\n\n    return model.ModelOutputs(\n        loss=loss,\n        end_state=(times[:, -self.input_window_size:],\n                   values[:, -self.input_window_size:, :],\n                   exogenous_regressors[:, -self.input_window_size:, :]),\n        predictions={\n            \"mean\": prediction,\n            \"covariance\": covariance,\n            \"observed\": original_values[:, -self.output_window_size:]\n        },\n        prediction_times=times[:, -self.output_window_size:])\n\n  def get_batch_loss(self, features, mode, state):\n    \"\"\"Computes predictions and a loss.\n\n    Args:\n      features: A dictionary (such as is produced by a chunker) with the\n        following key/value pairs (shapes are given as required for training):\n          TrainEvalFeatures.TIMES: A [batch size, self.window_size] integer\n            Tensor with times for each observation. To train on longer\n            sequences, the data should first be chunked.\n          TrainEvalFeatures.VALUES: A [batch size, self.window_size,\n            self.num_features] Tensor with values for each observation. When\n            evaluating, `TIMES` and `VALUES` must have a window size of at least\n            self.window_size, but it may be longer, in which case the last\n            window_size - self.input_window_size times (or fewer if this is not\n            divisible by self.output_window_size) will be evaluated on with\n            non-overlapping output windows (and will have associated\n            predictions). 
This is primarily to support qualitative\n            evaluation/plotting, and is not a recommended way to compute\n            evaluation losses (since there is no overlap in the output windows,\n            which for window-based models is an undesirable bias).\n      mode: The tf.estimator.ModeKeys mode to use (TRAIN or EVAL).\n      state: Unused\n\n    Returns:\n      A model.ModelOutputs object.\n    Raises:\n      ValueError: If `mode` is not TRAIN or EVAL, or if static shape information\n      is incorrect.\n    \"\"\"\n    features = {\n        feature_name: ops.convert_to_tensor(feature_value)\n        for feature_name, feature_value in features.items()\n    }\n    times = features[TrainEvalFeatures.TIMES]\n    exogenous_regressors = self._process_exogenous_features(\n        times=times,\n        features={\n            key: value for key, value in features.items() if key not in [\n                TrainEvalFeatures.TIMES, TrainEvalFeatures.VALUES,\n                PredictionFeatures.STATE_TUPLE\n            ]\n        })\n    if mode == estimator_lib.ModeKeys.TRAIN:\n      # For training, we require the window size to be self.window_size as\n      # iterating sequentially on larger windows could introduce a bias.\n      return self._process_window(\n          features, mode=mode, exogenous_regressors=exogenous_regressors)\n    elif mode == estimator_lib.ModeKeys.EVAL:\n      # For evaluation, we allow the user to pass in a larger window, in which\n      # case we try to cover as much of the window as possible without\n      # overlap. 
Quantitative evaluation is more efficient/correct with fixed\n      # windows matching self.window_size (as with training), but this looping\n      # allows easy plotting of \"in-sample\" predictions.\n      times.get_shape().assert_has_rank(2)\n      static_window_size = times.get_shape().dims[1].value\n      if (static_window_size is not None and\n          static_window_size < self.window_size):\n        raise ValueError(\n            (\"ARModel requires a window of at least input_window_size + \"\n             \"output_window_size to evaluate on (input_window_size={}, \"\n             \"output_window_size={}, and got shape {} for feature '{}' (batch \"\n             \"size, window size)).\").format(self.input_window_size,\n                                            self.output_window_size,\n                                            times.get_shape(),\n                                            TrainEvalFeatures.TIMES))\n      num_iterations = (\n          (tf.compat.v1.shape(times)[1] - self.input_window_size) //\n          self.output_window_size)\n      output_size = num_iterations * self.output_window_size\n      # Rather than dealing with overlapping windows of output, discard a bit at\n      # the beginning if output windows don't cover evenly.\n      crop_length = output_size + self.input_window_size\n      features = {\n          feature_name: feature_value[:, -crop_length:]\n          for feature_name, feature_value in features.items()\n      }\n\n      # Note that, unlike the ARModel's predict() while_loop, each iteration\n      # here can run in parallel, since we are not feeding predictions or state\n      # from previous iterations.\n      def _while_condition(iteration_number, loss_ta, mean_ta, covariance_ta):\n        del loss_ta, mean_ta, covariance_ta  # unused\n        return iteration_number < num_iterations\n\n      def _while_body(iteration_number, loss_ta, mean_ta, covariance_ta):\n        \"\"\"Perform a processing step on a single 
window of data.\"\"\"\n        base_offset = iteration_number * self.output_window_size\n        model_outputs = self._process_window(\n            features={\n                feature_name:\n                feature_value[:, base_offset:base_offset + self.window_size]\n                for feature_name, feature_value in features.items()\n            },\n            mode=mode,\n            exogenous_regressors=exogenous_regressors[:,\n                                                      base_offset:base_offset +\n                                                      self.window_size])\n        # This code needs to be updated if new predictions are added in\n        # self._process_window\n        assert len(model_outputs.predictions) == 3\n        assert \"mean\" in model_outputs.predictions\n        assert \"covariance\" in model_outputs.predictions\n        assert \"observed\" in model_outputs.predictions\n        return (iteration_number + 1,\n                loss_ta.write(iteration_number, model_outputs.loss),\n                mean_ta.write(iteration_number,\n                              model_outputs.predictions[\"mean\"]),\n                covariance_ta.write(iteration_number,\n                                    model_outputs.predictions[\"covariance\"]))\n\n      _, loss_ta, mean_ta, covariance_ta = tf.compat.v1.while_loop(\n          _while_condition, _while_body, [\n              0,\n              tf.TensorArray(dtype=self.dtype, size=num_iterations),\n              tf.TensorArray(dtype=self.dtype, size=num_iterations),\n              tf.TensorArray(dtype=self.dtype, size=num_iterations)\n          ])\n      values = tf.cast(features[TrainEvalFeatures.VALUES], dtype=self.dtype)\n      batch_size = tf.compat.v1.shape(times)[0]\n      prediction_shape = [\n          batch_size, self.output_window_size * num_iterations,\n          self.num_features\n      ]\n      (previous_state_times, previous_state_values,\n       previous_state_exogenous_regressors) = 
state\n      # Make sure returned state always has windows of self.input_window_size,\n      # even if we were passed fewer than self.input_window_size points this\n      # time.\n      if self.input_window_size > 0:\n        new_state_times = tf.concat(\n            [previous_state_times,\n             tf.cast(times, dtype=tf.dtypes.int64)],\n            axis=1)[:, -self.input_window_size:]\n        new_state_times.set_shape((None, self.input_window_size))\n        new_state_values = tf.concat(\n            [previous_state_values,\n             self._scale_data(values)], axis=1)[:, -self.input_window_size:, :]\n        new_state_values.set_shape(\n            (None, self.input_window_size, self.num_features))\n        new_exogenous_regressors = tf.concat(\n            [previous_state_exogenous_regressors, exogenous_regressors],\n            axis=1)[:, -self.input_window_size:, :]\n        new_exogenous_regressors.set_shape(\n            (None, self.input_window_size, self.exogenous_size))\n      else:\n        # There is no state to keep, and the strided slices above do not handle\n        # input_window_size=0.\n        new_state_times = previous_state_times\n        new_state_values = previous_state_values\n        new_exogenous_regressors = previous_state_exogenous_regressors\n      return model.ModelOutputs(\n          loss=tf.math.reduce_mean(loss_ta.stack(), axis=0),\n          end_state=(new_state_times, new_state_values,\n                     new_exogenous_regressors),\n          predictions={\n              \"mean\":\n                  tf.reshape(\n                      tf.compat.v1.transpose(mean_ta.stack(), [1, 0, 2, 3]),\n                      prediction_shape),\n              \"covariance\":\n                  tf.reshape(\n                      tf.compat.v1.transpose(covariance_ta.stack(),\n                                             [1, 0, 2, 3]), prediction_shape),\n              \"observed\":\n                  values[:, -output_size:]\n          
},\n          prediction_times=times[:, -output_size:])\n    else:\n      raise ValueError(\n          \"Unknown mode '{}' passed to get_batch_loss.\".format(mode))\n\n  def _compute_time_features(self, time):\n    \"\"\"Compute some features on the time value.\"\"\"\n    batch_size = tf.compat.v1.shape(time)[0]\n    num_periods = len(self._periodicities)\n    # Reshape to 3D.\n    periods = tf.constant(\n        self._periodicities, shape=[1, 1, num_periods, 1], dtype=time.dtype)\n    time = tf.reshape(time, [batch_size, -1, 1, 1])\n    window_offset = time / self._periodicities\n    # Cast to appropriate type and scale to [0, 1) range\n    mod = (\n        tf.cast(time % periods, self.dtype) * self._buckets /\n        tf.cast(periods, self.dtype))\n    # Bucketize based on some fixed width intervals. For a value t and interval\n    # [a, b), we return (t - a) if a <= t < b, else 0.\n    intervals = tf.reshape(\n        tf.range(self._buckets, dtype=self.dtype), [1, 1, 1, self._buckets])\n    mod = tf.nn.relu(mod - intervals)\n    mod = tf.where(mod < 1.0, mod, tf.compat.v1.zeros_like(mod))\n    return window_offset, mod\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/timeseries/ar_model_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for ar_model.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport functools\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator import estimator_lib\nfrom tensorflow_estimator.python.estimator.canned.timeseries import ar_model\nfrom tensorflow_estimator.python.estimator.canned.timeseries.estimators import LSTMAutoRegressor\nfrom tensorflow_estimator.python.estimator.canned.timeseries.feature_keys import PredictionFeatures\nfrom tensorflow_estimator.python.estimator.canned.timeseries.feature_keys import TrainEvalFeatures\n\n\n@test_util.run_v1_only(\"Currently incompatible with ResourceVariable\")\nclass ARModelTest(tf.test.TestCase):\n\n  def test_wrong_window_size(self):\n    estimator = LSTMAutoRegressor(\n        periodicities=10,\n        input_window_size=10,\n        output_window_size=6,\n        num_features=1)\n\n    def _bad_window_size_input_fn():\n      return ({\n          TrainEvalFeatures.TIMES: [[1]],\n          TrainEvalFeatures.VALUES: [[[1.]]]\n      }, None)\n\n    def _good_data():\n      return ({\n          TrainEvalFeatures.TIMES: tf.range(16)[None, :],\n          
TrainEvalFeatures.VALUES: tf.reshape(tf.range(16), [1, 16, 1])\n      }, None)\n\n    with self.assertRaisesRegexp(ValueError, \"set window_size=16\"):\n      estimator.train(input_fn=_bad_window_size_input_fn, steps=1)\n    # Get a checkpoint for evaluation\n    estimator.train(input_fn=_good_data, steps=1)\n    with self.assertRaisesRegexp(ValueError, \"requires a window of at least\"):\n      estimator.evaluate(input_fn=_bad_window_size_input_fn, steps=1)\n\n  def test_predictions_direct_lstm(self):\n    model = ar_model.ARModel(\n        periodicities=2,\n        num_features=1,\n        num_time_buckets=10,\n        input_window_size=2,\n        output_window_size=2,\n        prediction_model_factory=functools.partial(\n            ar_model.LSTMPredictionModel, num_units=16))\n    with tf.compat.v1.Session():\n      predicted_values = model.predict({\n          PredictionFeatures.TIMES: [[4, 6, 10]],\n          PredictionFeatures.STATE_TUPLE: ([[1, 2]], [[[1.], [2.]]], [[[], []]])\n      })\n      tf.compat.v1.initializers.global_variables().run()\n      self.assertAllEqual(predicted_values[\"mean\"].eval().shape, [1, 3, 1])\n\n  def test_long_eval(self):\n    model = ar_model.ARModel(\n        periodicities=2,\n        num_features=1,\n        num_time_buckets=10,\n        input_window_size=2,\n        output_window_size=1)\n    raw_features = {\n        TrainEvalFeatures.TIMES: [[1, 3, 5, 7, 11]],\n        TrainEvalFeatures.VALUES: [[[1.], [2.], [3.], [4.], [5.]]]\n    }\n    model.initialize_graph()\n    with tf.compat.v1.variable_scope(\"armodel\"):\n      raw_evaluation = model.define_loss(\n          raw_features, mode=estimator_lib.ModeKeys.EVAL)\n    with tf.compat.v1.Session() as sess:\n      tf.compat.v1.initializers.global_variables().run()\n      raw_evaluation_evaled = sess.run(raw_evaluation)\n      self.assertAllEqual([[5, 7, 11]], raw_evaluation_evaled.prediction_times)\n      for feature_name in raw_evaluation.predictions:\n        
self.assertAllEqual(\n            [1, 3, 1],  # batch, window, num_features. The window size has 2\n            # cut off for the first input_window.\n            raw_evaluation_evaled.predictions[feature_name].shape)\n\n  def test_long_eval_discard_indivisible(self):\n    model = ar_model.ARModel(\n        periodicities=2,\n        num_features=1,\n        num_time_buckets=10,\n        input_window_size=2,\n        output_window_size=2)\n    raw_features = {\n        TrainEvalFeatures.TIMES: [[1, 3, 5, 7, 11]],\n        TrainEvalFeatures.VALUES: [[[1.], [2.], [3.], [4.], [5.]]]\n    }\n    model.initialize_graph()\n    raw_evaluation = model.define_loss(\n        raw_features, mode=estimator_lib.ModeKeys.EVAL)\n    with tf.compat.v1.Session() as sess:\n      tf.compat.v1.initializers.global_variables().run()\n      raw_evaluation_evaled = sess.run(raw_evaluation)\n      self.assertAllEqual([[7, 11]], raw_evaluation_evaled.prediction_times)\n      for feature_name in raw_evaluation.predictions:\n        self.assertAllEqual(\n            [1, 2, 1],  # batch, window, num_features. The window has two cut\n            # off for the first input window and one discarded so\n            # that the remainder is divisible into output windows.\n            raw_evaluation_evaled.predictions[feature_name].shape)\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/timeseries/ar_model_training_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for training ar_model.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator import estimator_lib\nfrom tensorflow_estimator.python.estimator.canned.timeseries import ar_model\nfrom tensorflow_estimator.python.estimator.canned.timeseries.estimators import LSTMAutoRegressor\nfrom tensorflow_estimator.python.estimator.canned.timeseries.feature_keys import PredictionFeatures\nfrom tensorflow_estimator.python.estimator.canned.timeseries.feature_keys import TrainEvalFeatures\n\n\nclass InputFnBuilder(object):\n\n  def __init__(self,\n               noise_stddev,\n               periods,\n               window_size,\n               batch_size,\n               num_samples=200):\n    self.window_size = window_size\n    self.batch_size = batch_size\n\n    split = int(num_samples * 0.8)\n    self.initialize_data = lambda: self.initialize_data_with_properties(\n        noise_stddev, periods, num_samples, split)\n\n  def initialize_data_with_properties(self, noise_stddev, periods, num_samples,\n                                      split):\n    time = 1 + 3 * 
tf.range(num_samples, dtype=tf.dtypes.int64)\n    time_offset = 2 * math.pi * tf.cast(time % periods[0],\n                                        tf.dtypes.float32) / periods[0]\n    time_offset = time_offset[:, None]\n    if len(periods) > 1:\n      time_offset2 = tf.cast(time % periods[1], tf.dtypes.float32) / periods[1]\n      time_offset2 = time_offset2[:, None]\n      data1 = tf.math.sin(time_offset / 2.0)**2 * (1 + time_offset2)\n    else:\n      data1 = tf.math.sin(2 * time_offset) + tf.math.cos(3 * time_offset)\n    data1_noise = \\\n      noise_stddev / 4. * tf.random.normal([num_samples], 1)[:, None]\n    data1 = tf.math.add(data1, data1_noise)\n\n    data2 = tf.math.sin(3 * time_offset) + tf.math.cos(5 * time_offset)\n    data2_noise = \\\n      noise_stddev / 3. * tf.random.normal([num_samples], 1)[:, None]\n    data2 = tf.math.add(data2, data2_noise)\n    data = tf.concat((4 * data1, 3 * data2), 1)\n    self.train_data, self.test_data = data[0:split], data[split:]\n    self.train_time, self.test_time = time[0:split], time[split:]\n\n  def train_or_test_input_fn(self, time, data):\n\n    def map_to_dict(time, data):\n      return {TrainEvalFeatures.TIMES: time, TrainEvalFeatures.VALUES: data}\n\n    def batch_windows(time, data):\n      return tf.compat.v1.data.Dataset.zip((time, data)).batch(\n          self.window_size, drop_remainder=True)\n\n    dataset = tf.compat.v1.data.Dataset.from_tensor_slices((time, data))\n    dataset = dataset.window(self.window_size, shift=1, drop_remainder=True)\n    dataset = dataset.shuffle(1000, seed=2).repeat()\n    dataset = dataset.flat_map(batch_windows).batch(\n        self.batch_size).map(map_to_dict)\n    return dataset\n\n  def train_input_fn(self):\n    self.initialize_data()\n    return self.train_or_test_input_fn(self.train_time, self.train_data)\n\n  def test_input_fn(self):\n    self.initialize_data()\n    return self.train_or_test_input_fn(self.test_time, self.test_data)\n\n  def 
prediction_input_fn(self):\n\n    def map_to_dict(predict_times, predict_true_values, state_times,\n                    state_values, state_exogenous):\n      return ({\n          PredictionFeatures.TIMES:\n              predict_times[None, :],\n          TrainEvalFeatures.VALUES:\n              predict_true_values[None, :],\n          PredictionFeatures.STATE_TUPLE:\n              (state_times[None, :], state_values[None, :],\n               state_exogenous[None, :])\n      }, {})\n\n    self.initialize_data()\n    predict_times = tf.concat(\n        [self.train_time[self.window_size:], self.test_time], 0)[None, :]\n    predict_true_values = tf.concat(\n        [self.train_data[self.window_size:], self.test_data], 0)[None, :]\n    state_times = tf.cast(self.train_time[:self.window_size][None, :],\n                          tf.dtypes.float32)\n    state_values = tf.cast(self.train_data[:self.window_size, :][None, :],\n                           tf.dtypes.float32)\n    state_exogenous = state_times[:, :, None][:, :, :0]\n\n    dataset = tf.compat.v1.data.Dataset.from_tensor_slices(\n        (predict_times, predict_true_values, state_times, state_values,\n         state_exogenous))\n    dataset = dataset.map(map_to_dict)\n    return dataset\n\n  def true_values(self):\n    self.initialize_data()\n    predict_true_values = tf.concat(\n        [self.train_data[self.window_size:], self.test_data], 0)[None, :]\n    true_values = predict_true_values[0, :, 0]\n    return true_values\n\n\n@test_util.run_v1_only(\"Currently incompatible with ResourceVariable\")\nclass ARModelTrainingTest(tf.test.TestCase):\n\n  def train_helper(self, input_window_size, loss, max_loss=None, periods=(25,)):\n    data_noise_stddev = 0.2\n    if max_loss is None:\n      if loss == ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS:\n        max_loss = 1.0\n      else:\n        max_loss = 0.05 / (data_noise_stddev**2)\n    output_window_size = 10\n    window_size = input_window_size + output_window_size\n   
 input_fn_builder = InputFnBuilder(\n        noise_stddev=data_noise_stddev,\n        periods=periods,\n        window_size=window_size,\n        batch_size=64)\n\n    class _RunConfig(estimator_lib.RunConfig):\n\n      @property\n      def tf_random_seed(self):\n        return 3\n\n    estimator = LSTMAutoRegressor(\n        periodicities=periods,\n        input_window_size=input_window_size,\n        output_window_size=output_window_size,\n        num_features=2,\n        num_timesteps=20,\n        num_units=16,\n        loss=loss,\n        config=_RunConfig())\n\n    # Test training\n    # Note that most models will require many more steps to fully converge. We\n    # have used a small number of steps here to keep the running time small.\n    estimator.train(input_fn=input_fn_builder.train_input_fn, steps=75)\n    test_evaluation = estimator.evaluate(\n        input_fn=input_fn_builder.test_input_fn, steps=1)\n    test_loss = test_evaluation[\"loss\"]\n    tf.compat.v1.logging.warn(\"Final test loss: %f\", test_loss)\n    self.assertLess(test_loss, max_loss)\n    if loss == ar_model.ARModel.SQUARED_LOSS:\n      # Test that the evaluation loss is reported without input scaling.\n      self.assertAllClose(\n          test_loss,\n          tf.math.reduce_mean(\n              (test_evaluation[\"mean\"] - test_evaluation[\"observed\"])**2))\n\n    # Test predict\n    (predictions,) = tuple(\n        estimator.predict(input_fn=input_fn_builder.prediction_input_fn))\n    predicted_mean = predictions[\"mean\"][:, 0]\n\n    if loss == ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS:\n      variances = predictions[\"covariance\"][:, 0]\n      standard_deviations = tf.math.sqrt(variances)\n      # Note that we may get tighter bounds with more training steps.\n      true_values = input_fn_builder.true_values()\n      errors = tf.math.abs(predicted_mean -\n                           true_values) > 4 * standard_deviations\n      fraction_errors = tf.math.reduce_mean(tf.cast(errors, 
tf.dtypes.float32))\n      tf.compat.v1.logging.warn(\"Fraction errors: %f\",\n                                self.evaluate(fraction_errors))\n\n  def test_autoregression_squared(self):\n    self.train_helper(input_window_size=15, loss=ar_model.ARModel.SQUARED_LOSS)\n\n  def test_autoregression_short_input_window(self):\n    self.train_helper(input_window_size=8, loss=ar_model.ARModel.SQUARED_LOSS)\n\n  def test_autoregression_normal(self):\n    self.train_helper(\n        input_window_size=10,\n        loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS,\n        max_loss=50.)  # Just make sure there are no exceptions.\n\n  def test_autoregression_normal_multiple_periods(self):\n    self.train_helper(\n        input_window_size=10,\n        loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS,\n        max_loss=2.0,\n        periods=(25, 55))\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/timeseries/estimators.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Estimators for time series models.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport functools\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator import estimator_lib\nfrom tensorflow_estimator.python.estimator.canned import optimizers\nfrom tensorflow_estimator.python.estimator.canned.timeseries import ar_model\nfrom tensorflow_estimator.python.estimator.canned.timeseries import feature_keys\nfrom tensorflow_estimator.python.estimator.canned.timeseries import head as ts_head_lib\nfrom tensorflow_estimator.python.estimator.canned.timeseries import math_utils\nfrom tensorflow_estimator.python.estimator.canned.timeseries import state_management\nfrom tensorflow_estimator.python.estimator.export import export_lib\n\n\nclass TimeSeriesRegressor(estimator_lib.Estimator):\n  \"\"\"An Estimator to fit and evaluate a time series model.\"\"\"\n\n  def __init__(self,\n               model,\n               state_manager=None,\n               optimizer=None,\n               model_dir=None,\n               config=None,\n               head_type=ts_head_lib.TimeSeriesRegressionHead):\n    \"\"\"Initialize the Estimator.\n\n    Args:\n      model: The time series model 
to wrap (inheriting from TimeSeriesModel).\n      state_manager: The state manager to use, or (by default)\n        PassthroughStateManager if none is needed.\n      optimizer: The optimization algorithm to use when training, inheriting\n        from tf.train.Optimizer. Defaults to Adam with step size 0.02.\n      model_dir: See `Estimator`.\n      config: See `Estimator`.\n      head_type: The kind of head to use for the model (inheriting from\n        `TimeSeriesRegressionHead`).\n    \"\"\"\n    input_statistics_generator = math_utils.InputStatisticsFromMiniBatch(\n        dtype=model.dtype, num_features=model.num_features)\n    if state_manager is None:\n      if isinstance(model, ar_model.ARModel):\n        state_manager = state_management.FilteringOnlyStateManager()\n      else:\n        state_manager = state_management.PassthroughStateManager()\n    if optimizer is None:\n      optimizer = tf.compat.v1.train.AdamOptimizer(0.02)\n    self._model = model\n    ts_regression_head = head_type(\n        model=model,\n        state_manager=state_manager,\n        optimizer=optimizer,\n        input_statistics_generator=input_statistics_generator)\n    model_fn = ts_regression_head.create_estimator_spec\n    super(TimeSeriesRegressor, self).__init__(\n        model_fn=model_fn, model_dir=model_dir, config=config)\n\n  def _model_start_state_placeholders(self,\n                                      batch_size_tensor,\n                                      static_batch_size=None):\n    \"\"\"Creates placeholders with zeroed start state for the current model.\"\"\"\n    gathered_state = {}\n    # Models may not know the shape of their state without creating some\n    # variables/ops. Avoid polluting the default graph by making a new one. We\n    # use only static metadata from the returned Tensors.\n    with tf.Graph().as_default():\n      self._model.initialize_graph()\n\n      # Evaluate the initial state as same-dtype \"zero\" values. 
These zero\n      # constants aren't used, but are necessary for feeding to\n      # placeholder_with_default for the \"cold start\" case where state is not\n      # fed to the model.\n      def _zeros_like_constant(tensor):\n        return tf.get_static_value(tf.compat.v1.zeros_like(tensor))\n\n      start_state = tf.nest.map_structure(_zeros_like_constant,\n                                          self._model.get_start_state())\n    for prefixed_state_name, state in ts_head_lib.state_to_dictionary(\n        start_state).items():\n      state_shape_with_batch = tf.TensorShape(\n          (static_batch_size,)).concatenate(state.shape)\n      default_state_broadcast = tf.tile(\n          state[None, ...],\n          multiples=tf.concat(\n              [batch_size_tensor[None],\n               tf.ones(len(state.shape), dtype=tf.dtypes.int32)],\n              axis=0))\n      gathered_state[\n          prefixed_state_name] = tf.compat.v1.placeholder_with_default(\n              input=default_state_broadcast,\n              name=prefixed_state_name,\n              shape=state_shape_with_batch)\n    return gathered_state\n\n  def build_one_shot_parsing_serving_input_receiver_fn(self,\n                                                       filtering_length,\n                                                       prediction_length,\n                                                       default_batch_size=None,\n                                                       values_input_dtype=None,\n                                                       truncate_values=False):\n    \"\"\"Build an input_receiver_fn for export_saved_model accepting tf.Examples.\n\n    Only compatible with `OneShotPredictionHead` (see `head`).\n\n    Args:\n      filtering_length: The number of time steps used as input to the model, for\n        which values are provided. 
If more than `filtering_length` values are\n        provided (via `truncate_values`), only the first `filtering_length`\n        values are used.\n      prediction_length: The number of time steps requested as predictions from\n        the model. Times and all exogenous features must be provided for these\n        steps.\n      default_batch_size: If specified, must be a scalar integer. Sets the batch\n        size in the static shape information of all feature Tensors, which means\n        only this batch size will be accepted by the exported model. If None\n        (default), static shape information for batch sizes is omitted.\n      values_input_dtype: An optional dtype specification for values in the\n        tf.Example protos (either float32 or int64, since these are the numeric\n        types supported by tf.Example). After parsing, values are cast to the\n        model's dtype (float32 or float64).\n      truncate_values: If True, expects `filtering_length + prediction_length`\n        values to be provided, but only uses the first `filtering_length`. If\n        False (default), exactly `filtering_length` values must be provided.\n\n    Returns:\n      An input_receiver_fn which may be passed to the Estimator's\n      export_saved_model.\n\n      Expects features contained in a vector of serialized tf.Examples with\n      shape [batch size] (dtype `tf.string`), each tf.Example containing\n      features with the following shapes:\n        times: [filtering_length + prediction_length] integer\n        values: [filtering_length, num features] floating point. 
If\n          `truncate_values` is True, expects `filtering_length +\n          prediction_length` values but only uses the first `filtering_length`.\n        all exogenous features: [filtering_length + prediction_length, ...]\n          (various dtypes)\n    \"\"\"\n    if values_input_dtype is None:\n      values_input_dtype = tf.dtypes.float32\n    if truncate_values:\n      values_proto_length = filtering_length + prediction_length\n    else:\n      values_proto_length = filtering_length\n\n    def _serving_input_receiver_fn():\n      \"\"\"A receiver function to be passed to export_saved_model.\"\"\"\n      times_column = tf.feature_column.numeric_column(\n          key=feature_keys.TrainEvalFeatures.TIMES, dtype=tf.dtypes.int64)\n      values_column = tf.feature_column.numeric_column(\n          key=feature_keys.TrainEvalFeatures.VALUES,\n          dtype=values_input_dtype,\n          shape=(self._model.num_features,))\n      parsed_features_no_sequence = (\n          tf.compat.v1.feature_column.make_parse_example_spec(\n              list(self._model.exogenous_feature_columns) +\n              [times_column, values_column]))\n      parsed_features = {}\n      for key, feature_spec in parsed_features_no_sequence.items():\n        if isinstance(feature_spec, tf.io.FixedLenFeature):\n          if key == feature_keys.TrainEvalFeatures.VALUES:\n            parsed_features[key] = feature_spec._replace(\n                shape=((values_proto_length,) + feature_spec.shape))\n          else:\n            parsed_features[key] = feature_spec._replace(\n                shape=((filtering_length + prediction_length,) +\n                       feature_spec.shape))\n        elif feature_spec.dtype == tf.dtypes.string:\n          parsed_features[key] = tf.io.FixedLenFeature(\n              shape=(filtering_length + prediction_length,),\n              dtype=tf.dtypes.string)\n        else:  # VarLenFeature\n          raise ValueError(\"VarLenFeatures not supported, got %s for 
key %s\" %\n                           (feature_spec, key))\n      tfexamples = tf.compat.v1.placeholder(\n          shape=[default_batch_size], dtype=tf.dtypes.string, name=\"input\")\n      features = tf.compat.v1.io.parse_example(\n          serialized=tfexamples, features=parsed_features)\n      features[feature_keys.TrainEvalFeatures.TIMES] = tf.compat.v1.squeeze(\n          features[feature_keys.TrainEvalFeatures.TIMES], axis=-1)\n      features[feature_keys.TrainEvalFeatures.VALUES] = tf.cast(\n          features[feature_keys.TrainEvalFeatures.VALUES],\n          dtype=self._model.dtype)[:, :filtering_length]\n      features.update(\n          self._model_start_state_placeholders(\n              batch_size_tensor=tf.compat.v1.shape(\n                  features[feature_keys.TrainEvalFeatures.TIMES])[0],\n              static_batch_size=default_batch_size))\n      return export_lib.ServingInputReceiver(features, {\"examples\": tfexamples})\n\n    return _serving_input_receiver_fn\n\n  def build_raw_serving_input_receiver_fn(self,\n                                          default_batch_size=None,\n                                          default_series_length=None):\n    \"\"\"Build an input_receiver_fn for export_saved_model which accepts arrays.\n\n    Automatically creates placeholders for exogenous `FeatureColumn`s passed to\n    the model.\n\n    Args:\n      default_batch_size: If specified, must be a scalar integer. Sets the batch\n        size in the static shape information of all feature Tensors, which means\n        only this batch size will be accepted by the exported model. If None\n        (default), static shape information for batch sizes is omitted.\n      default_series_length: If specified, must be a scalar integer. Sets the\n        series length in the static shape information of all feature Tensors,\n        which means only this series length will be accepted by the exported\n        model. 
If None (default), static shape information for series length is\n        omitted.\n\n    Returns:\n      An input_receiver_fn which may be passed to the Estimator's\n      export_saved_model.\n    \"\"\"\n\n    def _serving_input_receiver_fn():\n      \"\"\"A receiver function to be passed to export_saved_model.\"\"\"\n      placeholders = {}\n      time_placeholder = tf.compat.v1.placeholder(\n          name=feature_keys.TrainEvalFeatures.TIMES,\n          dtype=tf.dtypes.int64,\n          shape=[default_batch_size, default_series_length])\n      placeholders[feature_keys.TrainEvalFeatures.TIMES] = time_placeholder\n      # Values are only necessary when filtering. For prediction the default\n      # value will be ignored.\n      placeholders[feature_keys.TrainEvalFeatures.VALUES] = (\n          tf.compat.v1.placeholder_with_default(\n              name=feature_keys.TrainEvalFeatures.VALUES,\n              input=tf.zeros(\n                  shape=[\n                      default_batch_size if default_batch_size else 0,\n                      default_series_length if default_series_length else 0,\n                      self._model.num_features\n                  ],\n                  dtype=self._model.dtype),\n              shape=(default_batch_size, default_series_length,\n                     self._model.num_features)))\n      if self._model.exogenous_feature_columns:\n        with tf.Graph().as_default():\n          # Default placeholders have only an unknown batch dimension. 
Make them\n          # in a separate graph, then splice in the series length to the shapes\n          # and re-create them in the outer graph.\n          parsed_features = (\n              tf.compat.v1.feature_column.make_parse_example_spec(\n                  self._model.exogenous_feature_columns))\n          placeholder_features = tf.compat.v1.io.parse_example(\n              serialized=tf.compat.v1.placeholder(\n                  shape=[None], dtype=tf.dtypes.string),\n              features=parsed_features)\n          exogenous_feature_shapes = {\n              key: (value.get_shape(), value.dtype)\n              for key, value in placeholder_features.items()\n          }\n        for feature_key, (batch_only_feature_shape,\n                          value_dtype) in (exogenous_feature_shapes.items()):\n          batch_only_feature_shape = (\n              batch_only_feature_shape.with_rank_at_least(1).as_list())\n          feature_shape = ([default_batch_size, default_series_length] +\n                           batch_only_feature_shape[1:])\n          placeholders[feature_key] = tf.compat.v1.placeholder(\n              dtype=value_dtype, name=feature_key, shape=feature_shape)\n      batch_size_tensor = tf.compat.v1.shape(time_placeholder)[0]\n      placeholders.update(\n          self._model_start_state_placeholders(\n              batch_size_tensor, static_batch_size=default_batch_size))\n      return export_lib.ServingInputReceiver(placeholders, placeholders)\n\n    return _serving_input_receiver_fn\n\n\n# TODO(b/113684821): Add detailed documentation on what the input_fn should do.\n# Add an example of making and returning a Dataset object. Determine if\n# endogenous features can be passed in as FeatureColumns. 
Move ARModel's loss\n# functions into a more general location.\nclass LSTMAutoRegressor(TimeSeriesRegressor):\n  \"\"\"An Estimator for an LSTM autoregressive model.\n\n  LSTMAutoRegressor is a window-based model, inputting fixed windows of length\n  `input_window_size` and outputting fixed windows of length\n  `output_window_size`. These two parameters must add up to the window_size\n  of data returned by the `input_fn`.\n\n  Each periodicity in the `periodicities` arg is divided by the `num_timesteps`\n  into timesteps that are represented as time features added to the model.\n\n  A good heuristic for picking an appropriate periodicity for a given data set\n  would be the length of cycles in the data. For example, energy usage in a\n  home is typically cyclic each day. If the time feature in a home energy\n  usage dataset is in the unit of hours, then 24 would be an appropriate\n  periodicity. Similarly, a good heuristic for `num_timesteps` is how often the\n  data is expected to change within the cycle. For the aforementioned home\n  energy usage dataset and periodicity of 24, then 48 would be a reasonable\n  value if usage is expected to change every half hour.\n\n  Each feature's value for a given example with time t is the difference\n  between t and the start of the timestep it falls under. 
If it doesn't fall\n  under a feature's associated timestep, then that feature's value is zero.\n\n  For example: if `periodicities` = (9, 12) and `num_timesteps` = 3, then 6\n  features would be added to the model, 3 for periodicity 9 and 3 for\n  periodicity 12.\n\n  For an example data point where t = 17:\n  - It's in the 3rd timestep for periodicity 9 (2nd period is 9-18 and 3rd\n    timestep is 15-18)\n  - It's in the 2nd timestep for periodicity 12 (2nd period is 12-24 and\n    2nd timestep is between 16-20).\n\n  Therefore the 6 added features for this row with t = 17 would be:\n\n  # Feature name (periodicity#_timestep#), feature value\n  P9_T1, 0 # not in first timestep\n  P9_T2, 0 # not in second timestep\n  P9_T3, 2 # 17 - 15 since 15 is the start of the 3rd timestep\n  P12_T1, 0 # not in first timestep\n  P12_T2, 1 # 17 - 16 since 16 is the start of the 2nd timestep\n  P12_T3, 0 # not in third timestep\n\n  Example Code:\n\n  ```python\n  extra_feature_columns = (\n      feature_column.numeric_column(\"exogenous_variable\"),\n  )\n\n  estimator = LSTMAutoRegressor(\n      periodicities=10,\n      input_window_size=10,\n      output_window_size=5,\n      model_dir=\"/path/to/model/dir\",\n      num_features=1,\n      extra_feature_columns=extra_feature_columns,\n      num_timesteps=50,\n      num_units=10,\n      optimizer=tf.train.ProximalAdagradOptimizer(...))\n\n  # Input builders\n  def input_fn_train():\n    return {\n      \"times\": tf.range(15)[None, :],\n      \"values\": tf.random_normal(shape=[1, 15, 1])\n    }\n  estimator.train(input_fn=input_fn_train, steps=100)\n\n  def input_fn_eval():\n    pass\n  metrics = estimator.evaluate(input_fn=input_fn_eval, steps=10)\n\n  def input_fn_predict():\n    pass\n  predictions = estimator.predict(input_fn=input_fn_predict)\n  ```\n  \"\"\"\n\n  def __init__(self,\n               periodicities,\n               input_window_size,\n               output_window_size,\n               model_dir=None,\n       
        num_features=1,\n               extra_feature_columns=None,\n               num_timesteps=10,\n               loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS,\n               num_units=128,\n               optimizer=\"Adam\",\n               config=None):\n    \"\"\"Initialize the Estimator.\n\n    Args:\n      periodicities: periodicities of the input data, in the same units as the\n        time feature (for example 24 if feeding hourly data with a daily\n        periodicity, or 60 * 24 if feeding minute-level data with daily\n        periodicity). Note this can be a single value or a list of values for\n        multiple periodicities.\n      input_window_size: Number of past time steps of data to look at when doing\n        the regression.\n      output_window_size: Number of future time steps to predict. Note that\n        setting this value to > 1 empirically seems to give a better fit.\n      model_dir: Directory to save model parameters, graph and etc. This can\n        also be used to load checkpoints from the directory into a estimator to\n        continue training a previously saved model.\n      num_features: The dimensionality of the time series (default value is one\n        for univariate, more than one for multivariate).\n      extra_feature_columns: A list of `tf.feature_column`s (for example\n        `tf.feature_column.embedding_column`) corresponding to features which\n        provide extra information to the model but are not part of the series to\n        be predicted.\n      num_timesteps: Number of buckets into which to divide (time %\n        periodicity). This value multiplied by the number of periodicities is\n        the number of time features added to the model.\n      loss: Loss function to use for training. Currently supported values are\n        SQUARED_LOSS and NORMAL_LIKELIHOOD_LOSS. Note that for\n        NORMAL_LIKELIHOOD_LOSS, we train the covariance term as well. 
For\n        SQUARED_LOSS, the evaluation loss is reported based on un-scaled\n        observations and predictions, while the training loss is computed on\n        normalized data.\n      num_units: The size of the hidden state in the encoder and decoder LSTM\n        cells.\n      optimizer: string, `tf.train.Optimizer` object, or callable that defines\n        the optimizer algorithm to use for training. Defaults to the Adam\n        optimizer with a learning rate of 0.01.\n      config: Optional `estimator.RunConfig` object to configure the runtime\n        settings.\n    \"\"\"\n    optimizer = optimizers.get_optimizer_instance(optimizer, learning_rate=0.01)\n    model = ar_model.ARModel(\n        periodicities=periodicities,\n        input_window_size=input_window_size,\n        output_window_size=output_window_size,\n        num_features=num_features,\n        exogenous_feature_columns=extra_feature_columns,\n        num_time_buckets=num_timesteps,\n        loss=loss,\n        prediction_model_factory=functools.partial(\n            ar_model.LSTMPredictionModel, num_units=num_units))\n    state_manager = state_management.FilteringOnlyStateManager()\n    super(LSTMAutoRegressor, self).__init__(\n        model=model,\n        state_manager=state_manager,\n        optimizer=optimizer,\n        model_dir=model_dir,\n        config=config,\n        head_type=ts_head_lib.OneShotPredictionHead)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/timeseries/estimators_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport functools\nimport tempfile\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow.python.ops import math_ops\nfrom tensorflow_estimator.python.estimator import estimator_lib\nfrom tensorflow_estimator.python.estimator.canned.timeseries import ar_model\nfrom tensorflow_estimator.python.estimator.canned.timeseries import estimators\nfrom tensorflow_estimator.python.estimator.canned.timeseries import feature_keys\nfrom tensorflow_estimator.python.estimator.canned.timeseries import saved_model_utils\n\n\nclass _SeedRunConfig(estimator_lib.RunConfig):\n\n  @property\n  def tf_random_seed(self):\n    return 3\n\n\ndef _generate_data():\n  time = tf.range(20, dtype=tf.dtypes.int64)\n  data = tf.reshape(tf.range(20, dtype=tf.dtypes.float32), (20, 1))\n  exogenous = data\n  return time, data, exogenous\n\n\ndef _build_input_fn_with_seed(seed):\n\n  def map_to_dict(time, data, exogenous):\n    return {\n        feature_keys.TrainEvalFeatures.TIMES: time,\n        feature_keys.TrainEvalFeatures.VALUES: data,\n        \"exogenous\": exogenous\n    }\n\n  def batch_windows(time, data, exogenous):\n  
  return tf.compat.v1.data.Dataset.zip((time, data, exogenous)).batch(\n        16, drop_remainder=True)\n\n  def input_fn():\n    dataset = tf.compat.v1.data.Dataset.from_tensor_slices(_generate_data())\n    dataset = dataset.window(16, shift=1, drop_remainder=True)\n    dataset = dataset.shuffle(1000, seed=seed).repeat()\n    dataset = dataset.flat_map(batch_windows).batch(16).map(map_to_dict)\n    return dataset\n\n  return input_fn\n\n\n@test_util.run_v1_only(\"Currently incompatible with ResourceVariable\")\nclass TimeSeriesRegressorTest(tf.test.TestCase):\n\n  def _fit_restore_fit_test_template(self, estimator_fn, test_saved_model):\n    \"\"\"Tests restoring previously fit models.\"\"\"\n    temp_dir = self.get_temp_dir()\n    model_dir = tempfile.mkdtemp(dir=temp_dir)\n    exogenous_feature_columns = (tf.feature_column.numeric_column(\"exogenous\"),)\n    first_estimator = estimator_fn(model_dir, exogenous_feature_columns)\n    train_input_fn = _build_input_fn_with_seed(2)\n    eval_input_fn = _build_input_fn_with_seed(3)\n    first_estimator.train(input_fn=train_input_fn, steps=1)\n    first_evaluation = first_estimator.evaluate(input_fn=eval_input_fn, steps=1)\n    first_loss_before_fit = first_evaluation[\"loss\"]\n    self.assertAllEqual(first_loss_before_fit, first_evaluation[\"average_loss\"])\n    self.assertAllEqual([], first_loss_before_fit.shape)\n    first_estimator.train(input_fn=train_input_fn, steps=1)\n    first_loss_after_fit = first_estimator.evaluate(\n        input_fn=eval_input_fn, steps=1)[\"loss\"]\n    self.assertAllEqual([], first_loss_after_fit.shape)\n    second_estimator = estimator_fn(model_dir, exogenous_feature_columns)\n    second_estimator.train(input_fn=train_input_fn, steps=1)\n    second_evaluation = second_estimator.evaluate(\n        input_fn=eval_input_fn, steps=1)\n    exogenous_values_ten_steps = {\n        \"exogenous\": tf.range(10, dtype=tf.dtypes.float32)[None, :, None]\n    }\n    input_receiver_fn = 
first_estimator.build_raw_serving_input_receiver_fn()\n    export_location = first_estimator.export_saved_model(\n        temp_dir, input_receiver_fn)\n    if not test_saved_model:\n      return\n    with tf.Graph().as_default():\n      with tf.compat.v1.Session() as sess:\n        signatures = tf.compat.v1.saved_model.load(sess,\n                                                   [tf.saved_model.SERVING],\n                                                   export_location)\n        # Test that prediction and filtering can continue from evaluation output\n        _ = saved_model_utils.predict_continuation(\n            continue_from=second_evaluation,\n            steps=10,\n            exogenous_features=exogenous_values_ten_steps,\n            signatures=signatures,\n            session=sess)\n        times, values, _ = _generate_data()\n        first_filtering = saved_model_utils.filter_continuation(\n            continue_from=second_evaluation,\n            features={\n                feature_keys.FilteringFeatures.TIMES: times[None, -1] + 2,\n                feature_keys.FilteringFeatures.VALUES: values[None, -1] + 2.,\n                \"exogenous\": values[None, -1, None] + 12.\n            },\n            signatures=signatures,\n            session=sess)\n        # Test that prediction and filtering can continue from filtering output\n        second_saved_prediction = saved_model_utils.predict_continuation(\n            continue_from=first_filtering,\n            steps=1,\n            exogenous_features={\n                \"exogenous\": tf.range(1, dtype=tf.dtypes.float32)[None, :, None]\n            },\n            signatures=signatures,\n            session=sess)\n        self.assertEqual(\n            times[-1] + 3,\n            tf.compat.v1.squeeze(\n                second_saved_prediction[feature_keys.PredictionResults.TIMES]))\n        saved_model_utils.filter_continuation(\n            continue_from=first_filtering,\n            features={\n           
     feature_keys.FilteringFeatures.TIMES: times[-1] + 3,\n                feature_keys.FilteringFeatures.VALUES: values[-1] + 3.,\n                \"exogenous\": values[-1, None] + 13.\n            },\n            signatures=signatures,\n            session=sess)\n\n        # Test cold starting\n        six.assertCountEqual(\n            self, [\n                feature_keys.FilteringFeatures.TIMES,\n                feature_keys.FilteringFeatures.VALUES, \"exogenous\"\n            ], signatures.signature_def[\n                feature_keys.SavedModelLabels.COLD_START_FILTER].inputs.keys())\n        batched_times = tf.tile(\n            tf.range(30, dtype=tf.dtypes.int64)[None, :], (10, 1))\n        batched_values = tf.ones([10, 30, 1])\n        state = saved_model_utils.cold_start_filter(\n            signatures=signatures,\n            session=sess,\n            features={\n                feature_keys.FilteringFeatures.TIMES: batched_times,\n                feature_keys.FilteringFeatures.VALUES: batched_values,\n                \"exogenous\": 10. 
+ batched_values\n            })\n        predict_times = math_ops.tile(\n            tf.range(30, 45, dtype=tf.dtypes.int64)[None, :], (10, 1))\n        predictions = saved_model_utils.predict_continuation(\n            continue_from=state,\n            times=predict_times,\n            exogenous_features={\n                \"exogenous\":\n                    math_ops.tile(tf.range(15, dtype=tf.dtypes.float32), (10,))\n                    [None, :, None]\n            },\n            signatures=signatures,\n            session=sess)\n        self.assertAllEqual([10, 15, 1], predictions[\"mean\"].shape)\n\n  def disabled_test_time_series_regressor(self):\n\n    def _estimator_fn(model_dir, exogenous_feature_columns):\n      return estimators.TimeSeriesRegressor(\n          model=ar_model.ARModel(\n              periodicities=10,\n              input_window_size=10,\n              output_window_size=6,\n              num_features=1,\n              exogenous_feature_columns=exogenous_feature_columns,\n              prediction_model_factory=functools.partial(\n                  ar_model.LSTMPredictionModel, num_units=10)),\n          config=_SeedRunConfig(),\n          model_dir=model_dir)\n\n    self._fit_restore_fit_test_template(_estimator_fn, test_saved_model=True)\n\n  def test_ar_lstm_regressor(self):\n\n    def _estimator_fn(model_dir, exogenous_feature_columns):\n      return estimators.LSTMAutoRegressor(\n          periodicities=10,\n          input_window_size=10,\n          output_window_size=6,\n          model_dir=model_dir,\n          num_features=1,\n          extra_feature_columns=exogenous_feature_columns,\n          num_units=10,\n          config=_SeedRunConfig())\n\n    # LSTMAutoRegressor uses OneShotPredictionHead which does not work with\n    # saved models.\n    self._fit_restore_fit_test_template(_estimator_fn, test_saved_model=False)\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/timeseries/feature_keys.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Commonly used special feature names for time series models.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\n\n\nclass State(object):\n  \"\"\"Key formats for accepting/returning state.\"\"\"\n  # The model-dependent state to start from, as a single tuple.\n  STATE_TUPLE = \"start_tuple\"\n  # Same meaning as STATE_TUPLE, but prefixes keys representing flattened model\n  # state rather than mapping to a nested tuple containing model state,\n  # primarily for use with export_saved_model.\n  STATE_PREFIX = \"model_state\"\n\n\nclass Times(object):\n  \"\"\"Key formats for accepting/returning times.\"\"\"\n  # An increasing vector of integers.\n  TIMES = \"times\"\n\n\nclass Values(object):\n  \"\"\"Key formats for accepting/returning values.\"\"\"\n  # Floating point, with one or more values corresponding to each time in TIMES.\n  VALUES = \"values\"\n\n\nclass TrainEvalFeatures(Times, Values):\n  \"\"\"Feature names used during training and evaluation.\"\"\"\n  pass\n\n\nclass PredictionFeatures(Times, State):\n  \"\"\"Feature names used during prediction.\"\"\"\n  pass\n\n\nclass FilteringFeatures(Times, Values, State):\n  \"\"\"Special feature names 
for filtering.\"\"\"\n  pass\n\n\nclass PredictionResults(Times):\n  \"\"\"Keys returned when predicting (not comprehensive).\"\"\"\n  pass\n\n\nclass FilteringResults(Times, State):\n  \"\"\"Keys returned from evaluation/filtering.\"\"\"\n  pass\n\n\nclass SavedModelLabels(object):\n  \"\"\"Names of signatures exported with export_saved_model.\"\"\"\n  PREDICT = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n  FILTER = \"filter\"\n  COLD_START_FILTER = \"cold_start_filter\"\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/timeseries/head.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Timeseries head.\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport re\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nfrom tensorflow_estimator.python.estimator import estimator_lib\nfrom tensorflow_estimator.python.estimator.canned import head as head_lib\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned.timeseries import feature_keys\nfrom tensorflow_estimator.python.estimator.export import export_lib\n\n\nclass _NoStatePredictOutput(export_lib.PredictOutput):\n\n  def as_signature_def(self, receiver_tensors):\n    no_state_receiver_tensors = {\n        key: value\n        for key, value in receiver_tensors.items()\n        if not key.startswith(feature_keys.State.STATE_PREFIX)\n    }\n    return super(\n        _NoStatePredictOutput,\n        self).as_signature_def(receiver_tensors=no_state_receiver_tensors)\n\n\nclass TimeSeriesRegressionHead(head_lib._Head):  # pylint:disable=protected-access\n  \"\"\"Determines input and output signatures for a time series model.\"\"\"\n\n  def __init__(self,\n               model,\n               state_manager,\n               optimizer,\n   
            input_statistics_generator=None,\n               name=None):\n    \"\"\"Creates a `_Head` for time series regression.\n\n    Args:\n      model: A model for time series regression.\n      state_manager: A state manager.\n      optimizer: An optimizer.\n      input_statistics_generator: A input statistics generator.\n      name: An optional name for the model.\n    \"\"\"\n    self.model = model\n    self.state_manager = state_manager\n    self.optimizer = optimizer\n    self.input_statistics_generator = input_statistics_generator\n    self._name = name\n\n  @property\n  def name(self):\n    return self._name\n\n  # TODO(terrytangyuan): consolidate `model_outputs` and `_Head.LossSpec`\n  # once `_Head.create_loss` becomes extendable\n  def create_loss(self, features, mode, logits=None, labels=None):\n    \"\"\"See `_Head`.\"\"\"\n    model_outputs = self.state_manager.define_loss(self.model, features, mode)\n    tf.compat.v1.summary.scalar(\n        head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS),\n        model_outputs.loss)\n    return model_outputs\n\n  @property\n  def logits_dimension(self):\n    \"\"\"See `_Head`.\"\"\"\n    return 1\n\n  def _train_ops(self, features):\n    \"\"\"Add training ops to the graph.\"\"\"\n    mode = estimator_lib.ModeKeys.TRAIN\n    with tf.compat.v1.variable_scope(\n        \"model\",\n        # Use ResourceVariables to avoid race conditions.\n        use_resource=True):\n      model_outputs = self.create_loss(features, mode)\n\n    train_op = self.optimizer.minimize(\n        model_outputs.loss, global_step=tf.compat.v1.train.get_global_step())\n    return estimator_lib.EstimatorSpec(\n        loss=model_outputs.loss, mode=mode, train_op=train_op)\n\n  def _evaluate_ops(self, features):\n    \"\"\"Add ops for evaluation (aka filtering) to the graph.\"\"\"\n    mode = estimator_lib.ModeKeys.EVAL\n    with tf.compat.v1.variable_scope(\"model\", use_resource=True):\n      model_outputs = 
self.create_loss(features, mode)\n    metrics = {}\n    # Just output in-sample predictions for the last chunk seen\n    for prediction_key, prediction_value in model_outputs.predictions.items():\n      metrics[prediction_key] = _identity_metric_single(prediction_key,\n                                                        prediction_value)\n    metrics[feature_keys.FilteringResults.TIMES] = _identity_metric_single(\n        feature_keys.FilteringResults.TIMES, model_outputs.prediction_times)\n    metrics[feature_keys.FilteringResults.STATE_TUPLE] = (\n        _identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE,\n                                model_outputs.end_state))\n    metrics[metric_keys.MetricKeys.LOSS_MEAN] = tf.compat.v1.metrics.mean(\n        model_outputs.loss, name=\"average_loss\")\n    return estimator_lib.EstimatorSpec(\n        loss=model_outputs.loss,\n        mode=mode,\n        eval_metric_ops=metrics,\n        # needed for custom metrics.\n        predictions=model_outputs.predictions)\n\n  def _predict_ops(self, features):\n    \"\"\"Add ops for prediction to the graph.\"\"\"\n    with tf.compat.v1.variable_scope(\"model\", use_resource=True):\n      prediction = self.model.predict(features=features)\n    prediction[feature_keys.PredictionResults.TIMES] = features[\n        feature_keys.PredictionFeatures.TIMES]\n    return estimator_lib.EstimatorSpec(\n        predictions=prediction, mode=estimator_lib.ModeKeys.PREDICT)\n\n  def _serving_ops(self, features):\n    \"\"\"Add ops for serving to the graph.\"\"\"\n    with tf.compat.v1.variable_scope(\"model\", use_resource=True):\n      prediction_outputs = self.model.predict(features=features)\n    with tf.compat.v1.variable_scope(\"model\", reuse=True):\n      filtering_outputs = self.create_loss(features,\n                                           estimator_lib.ModeKeys.EVAL)\n    with tf.compat.v1.variable_scope(\"model\", reuse=True):\n      no_state_features = {\n          
k: v\n          for k, v in features.items()\n          if not k.startswith(feature_keys.State.STATE_PREFIX)\n      }\n      # Ignore any state management when cold-starting. The model's default\n      # start state is replicated across the batch.\n      cold_filtering_outputs = self.model.define_loss(\n          features=no_state_features, mode=estimator_lib.ModeKeys.EVAL)\n    return estimator_lib.EstimatorSpec(\n        mode=estimator_lib.ModeKeys.PREDICT,\n        export_outputs={\n            feature_keys.SavedModelLabels.PREDICT:\n                export_lib.PredictOutput(prediction_outputs),\n            feature_keys.SavedModelLabels.FILTER:\n                export_lib.PredictOutput(\n                    state_to_dictionary(filtering_outputs.end_state)),\n            feature_keys.SavedModelLabels.COLD_START_FILTER:\n                _NoStatePredictOutput(\n                    state_to_dictionary(cold_filtering_outputs.end_state))\n        },\n        # Likely unused, but it is necessary to return `predictions` to satisfy\n        # the Estimator's error checking.\n        predictions={})\n\n  def _convert_feature_to_tensor(self, name, value):\n    \"\"\"Casts features to the correct dtype based on their name.\"\"\"\n    if name in [\n        feature_keys.TrainEvalFeatures.TIMES,\n        feature_keys.PredictionFeatures.TIMES\n    ]:\n      return tf.cast(value, tf.dtypes.int64)\n    if name == feature_keys.TrainEvalFeatures.VALUES:\n      return tf.cast(value, self.model.dtype)\n    if name == feature_keys.PredictionFeatures.STATE_TUPLE:\n      return value  # Correct dtypes are model-dependent\n    return tf.compat.v1.convert_to_tensor_or_sparse_tensor(value)\n\n  def _gather_state(self, features):\n    \"\"\"Returns `features` with state packed, indicates if packing was done.\"\"\"\n    prefixed_state_re = re.compile(r\"^\" + feature_keys.State.STATE_PREFIX +\n                                   r\"_(\\d+)$\")\n    numbered_state = []\n    for key, tensor in 
features.items():\n      search_result = prefixed_state_re.search(key)\n      if search_result:\n        numbered_state.append((int(search_result.group(1)), key, tensor))\n    if not numbered_state:\n      return features, False\n    features = features.copy()\n    for _, key, _ in numbered_state:\n      del features[key]\n    numbered_state.sort(key=lambda item: item[0])\n    features[feature_keys.State.STATE_TUPLE] = tf.nest.pack_sequence_as(\n        structure=self.model.get_start_state(),\n        flat_sequence=[tensor for _, _, tensor in numbered_state])\n    return features, True\n\n  def _check_predict_features(self, features):\n    \"\"\"Raises errors if features are not suitable for prediction.\"\"\"\n    if feature_keys.PredictionFeatures.TIMES not in features:\n      raise ValueError(\"Expected a '{}' feature for prediction.\".format(\n          feature_keys.PredictionFeatures.TIMES))\n    if feature_keys.PredictionFeatures.STATE_TUPLE not in features:\n      raise ValueError(\"Expected a '{}' feature for prediction.\".format(\n          feature_keys.PredictionFeatures.STATE_TUPLE))\n    times_feature = features[feature_keys.PredictionFeatures.TIMES]\n    if not times_feature.get_shape().is_compatible_with([None, None]):\n      raise ValueError(\n          (\"Expected shape (batch dimension, window size) for feature '{}' \"\n           \"(got shape {})\").format(feature_keys.PredictionFeatures.TIMES,\n                                    times_feature.get_shape()))\n    _check_feature_shapes_compatible_with(\n        features=features,\n        compatible_with_name=feature_keys.PredictionFeatures.TIMES,\n        compatible_with_value=times_feature,\n        ignore=set([\n            # Model-dependent shapes\n            feature_keys.PredictionFeatures.STATE_TUPLE\n        ]))\n\n  def create_estimator_spec(self, features, mode, labels=None):\n    \"\"\"Performs basic error checking and returns an EstimatorSpec.\"\"\"\n    with 
ops.name_scope(self._name, \"head\"):\n      # for better error messages.\n      if labels is not None and not (isinstance(labels, dict) and labels == {}):  # pylint: disable=g-explicit-bool-comparison\n        raise ValueError(\n            \"The model received a `labels`, which is not supported. \"\n            \"Pass '{}' and '{}' as features.\".format(\n                feature_keys.TrainEvalFeatures.TIMES,\n                feature_keys.TrainEvalFeatures.VALUES))\n      del labels\n      features = {\n          name: self._convert_feature_to_tensor(name=name, value=value)\n          for name, value in features.items()\n      }\n      if self.input_statistics_generator is not None:\n        input_statistics = self.input_statistics_generator.initialize_graph(\n            features, update_statistics=(mode == estimator_lib.ModeKeys.TRAIN))\n      else:\n        input_statistics = None\n      self.model.initialize_graph(input_statistics=input_statistics)\n\n      # _gather_state requires the model to have its graph initialized (so it\n      # has access to the structure of the model's state)\n      features, passed_flat_state = self._gather_state(features)\n      if (mode == estimator_lib.ModeKeys.TRAIN or\n          mode == estimator_lib.ModeKeys.EVAL):\n        _check_train_eval_features(features, self.model)\n      elif mode == estimator_lib.ModeKeys.PREDICT:\n        self._check_predict_features(features)\n      else:\n        raise ValueError(\"Unknown mode '{}' passed to model_fn.\".format(mode))\n\n      self.state_manager.initialize_graph(\n          model=self.model, input_statistics=input_statistics)\n\n      if mode == estimator_lib.ModeKeys.TRAIN:\n        return self._train_ops(features)\n      elif mode == estimator_lib.ModeKeys.EVAL:\n        return self._evaluate_ops(features)\n      elif mode == estimator_lib.ModeKeys.PREDICT and not passed_flat_state:\n        return self._predict_ops(features)\n      elif mode == estimator_lib.ModeKeys.PREDICT and 
passed_flat_state:\n        # The mode is PREDICT, but we're actually in export_saved_model for\n        # serving. We want to return two graphs: one for filtering (state + data\n        # -> state) and one for predicting (state -> prediction).\n        return self._serving_ops(features)\n\n\nclass OneShotPredictionHead(TimeSeriesRegressionHead):\n  \"\"\"A time series head which exports a single stateless serving signature.\n\n  The serving default signature exported by this head expects `times`, `values`,\n  and any exogenous features, but no state. `values` has shape `[batch_size,\n  filter_length, num_features]` and `times` has shape `[batch_size,\n  total_length]`, where `total_length > filter_length`. Any exogenous features\n  must have their shapes prefixed by the shape of the `times` feature.\n\n  When serving, first performs filtering on the series up to `filter_length`\n  starting from the default start state for the model, then computes predictions\n  on the remainder of the series, returning them.\n\n  Model state is neither accepted nor returned, so filtering must be performed\n  each time predictions are requested when using this head.\n  \"\"\"\n\n  def _check_predict_features(self, features):\n    \"\"\"Raises errors if features are not suitable for one-shot prediction.\"\"\"\n    if feature_keys.PredictionFeatures.TIMES not in features:\n      raise ValueError(\"Expected a '{}' feature for prediction.\".format(\n          feature_keys.PredictionFeatures.TIMES))\n    if feature_keys.TrainEvalFeatures.VALUES not in features:\n      raise ValueError(\"Expected a '{}' feature for prediction.\".format(\n          feature_keys.TrainEvalFeatures.VALUES))\n    if feature_keys.PredictionFeatures.STATE_TUPLE not in features:\n      raise ValueError(\"Expected a '{}' feature for prediction.\".format(\n          feature_keys.PredictionFeatures.STATE_TUPLE))\n    times_feature = features[feature_keys.PredictionFeatures.TIMES]\n    if not 
times_feature.get_shape().is_compatible_with([None, None]):\n      raise ValueError(\n          (\"Expected shape (batch dimension, window size) for feature '{}' \"\n           \"(got shape {})\").format(feature_keys.PredictionFeatures.TIMES,\n                                    times_feature.get_shape()))\n    _check_feature_shapes_compatible_with(\n        features=features,\n        compatible_with_name=feature_keys.PredictionFeatures.TIMES,\n        compatible_with_value=times_feature,\n        ignore=set([\n            # Model-dependent shapes\n            feature_keys.PredictionFeatures.STATE_TUPLE,\n            # One shot prediction head relies on values being shorter than\n            # times. Even though we're predicting eventually, we need values for\n            # the filtering phase.\n            feature_keys.TrainEvalFeatures.VALUES,\n        ]))\n\n  def _evaluate_ops(self, features):\n    \"\"\"Add ops for evaluation (aka filtering) to the graph.\"\"\"\n    spec = super(OneShotPredictionHead, self)._evaluate_ops(features)\n    # No state is fed to OneShotPredictionHead, so we don't return it; it being\n    # a tuple can cause issues for downstream infrastructure.\n    del spec.eval_metric_ops[feature_keys.State.STATE_TUPLE]\n    return spec\n\n  def _serving_ops(self, features):\n    \"\"\"Add ops for serving to the graph.\"\"\"\n    with tf.compat.v1.variable_scope(\"model\", use_resource=True):\n      filtering_features = {}\n      prediction_features = {}\n      values_length = tf.compat.v1.shape(\n          features[feature_keys.FilteringFeatures.VALUES])[1]\n      for key, value in features.items():\n        if key == feature_keys.State.STATE_TUPLE:\n          # Ignore state input. 
The model's default start state is replicated\n          # across the batch.\n          continue\n        if key == feature_keys.FilteringFeatures.VALUES:\n          filtering_features[key] = value\n        else:\n          filtering_features[key] = value[:, :values_length]\n          prediction_features[key] = value[:, values_length:]\n      cold_filtering_outputs = self.model.define_loss(\n          features=filtering_features, mode=estimator_lib.ModeKeys.EVAL)\n      prediction_features[feature_keys.State.STATE_TUPLE] = (\n          cold_filtering_outputs.end_state)\n    with tf.compat.v1.variable_scope(\"model\", reuse=True):\n      prediction_outputs = self.model.predict(features=prediction_features)\n    return estimator_lib.EstimatorSpec(\n        mode=estimator_lib.ModeKeys.PREDICT,\n        export_outputs={\n            feature_keys.SavedModelLabels.PREDICT:\n                _NoStatePredictOutput(prediction_outputs),\n        },\n        # Likely unused, but it is necessary to return `predictions` to satisfy\n        # the Estimator's error checking.\n        predictions={})\n\n\ndef _check_feature_shapes_compatible_with(features,\n                                          compatible_with_name,\n                                          compatible_with_value,\n                                          ignore=None):\n  \"\"\"Checks all features are compatible with the given time-like feature.\"\"\"\n  if ignore is None:\n    ignore = set()\n  for name, value in features.items():\n    if name in ignore:\n      continue\n    feature_shape = value.get_shape()\n    if feature_shape.ndims is None:\n      continue\n    if feature_shape.ndims < 2:\n      raise ValueError(\n          (\"Features must have shape (batch dimension, window size, ...) 
\"\n           \"(got rank {} for feature '{}')\").format(feature_shape.ndims, name))\n    if not feature_shape[:2].is_compatible_with(\n        compatible_with_value.get_shape()):\n      raise ValueError(\n          (\"Features must have shape (batch dimension, window size, ...) \"\n           \"where batch dimension and window size match the \"\n           \"'{times_feature}' feature (got shape {feature_shape} for \"\n           \"feature '{feature_name}' but shape {times_shape} for feature \"\n           \"'{times_feature}')\").format(\n               times_feature=compatible_with_name,\n               feature_shape=feature_shape,\n               feature_name=name,\n               times_shape=compatible_with_value.get_shape()))\n\n\ndef _check_train_eval_features(features, model):\n  \"\"\"Raise errors if features are not suitable for training/evaluation.\"\"\"\n  if feature_keys.TrainEvalFeatures.TIMES not in features:\n    raise ValueError(\"Expected a '{}' feature for training/evaluation.\".format(\n        feature_keys.TrainEvalFeatures.TIMES))\n  if feature_keys.TrainEvalFeatures.VALUES not in features:\n    raise ValueError(\"Expected a '{}' feature for training/evaluation.\".format(\n        feature_keys.TrainEvalFeatures.VALUES))\n  times_feature = features[feature_keys.TrainEvalFeatures.TIMES]\n  if not times_feature.get_shape().is_compatible_with([None, None]):\n    raise ValueError(\n        (\"Expected shape (batch dimension, window size) for feature '{}' \"\n         \"(got shape {})\").format(feature_keys.TrainEvalFeatures.TIMES,\n                                  times_feature.get_shape()))\n  values_feature = features[feature_keys.TrainEvalFeatures.VALUES]\n  if not values_feature.get_shape().is_compatible_with(\n      [None, None, model.num_features]):\n    raise ValueError(\n        (\"Expected shape (batch dimension, window size, {num_features}) \"\n         \"for feature '{feature_name}', since the model was configured \"\n         \"with 
num_features={num_features} (got shape {got_shape})\").format(\n             num_features=model.num_features,\n             feature_name=feature_keys.TrainEvalFeatures.VALUES,\n             got_shape=values_feature.get_shape()))\n  _check_feature_shapes_compatible_with(\n      features=features,\n      compatible_with_name=feature_keys.TrainEvalFeatures.TIMES,\n      compatible_with_value=times_feature,\n      ignore=set([\n          feature_keys.State.STATE_TUPLE  # Model-dependent shapes\n      ]))\n\n\ndef _identity_metric_single(name, input_tensor):\n  \"\"\"A metric which takes on its last updated value.\n\n  This keeps evaluation metrics in sync with one another, since update ops are\n  run separately from their result Tensors. Simply returning (input_tensor,\n  no_op) as a metric with a value but no update means that a metric will come\n  from a different batch of data than metrics which cache values in a Variable\n  (e.g. the default loss metric).\n\n  Args:\n    name: A name for the metric.\n    input_tensor: Any Tensor.\n\n  Returns:\n    A tuple of (value, update_op).\n  \"\"\"\n  metric_variable = tf.compat.v1.Variable(\n      name=\"{}_identity_metric\".format(name),\n      initial_value=tf.zeros([], dtype=input_tensor.dtype),\n      collections=[tf.compat.v1.GraphKeys.LOCAL_VARIABLES],\n      validate_shape=False)\n  update_op = tf.compat.v1.assign(\n      metric_variable, input_tensor, validate_shape=False)\n  # This shape will be correct once the first update runs (but may be\n  # incomplete, so is not helpful for initializing the variable).\n  metric_variable.set_shape(input_tensor.get_shape())\n  return (metric_variable.value(), update_op)\n\n\ndef _identity_metric_nested(name, input_tensors):\n  \"\"\"Create identity metrics for a nested tuple of Tensors.\"\"\"\n  update_ops = []\n  value_tensors = []\n  for tensor_number, tensor in enumerate(tf.nest.flatten(input_tensors)):\n    value_tensor, update_op = _identity_metric_single(\n        
name=\"{}_{}\".format(name, tensor_number), input_tensor=tensor)\n    update_ops.append(update_op)\n    value_tensors.append(value_tensor)\n  return (tf.nest.pack_sequence_as(input_tensors, value_tensors),\n          tf.group(*update_ops))\n\n\ndef state_to_dictionary(state_tuple):\n  \"\"\"Flatten model state into a dictionary with string keys.\"\"\"\n  flattened = {}\n  for state_number, state_value in enumerate(tf.nest.flatten(state_tuple)):\n    prefixed_state_name = \"{}_{:02d}\".format(feature_keys.State.STATE_PREFIX,\n                                             state_number)\n    flattened[prefixed_state_name] = state_value\n  return flattened\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/timeseries/head_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for head.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport functools\nimport os\n\nfrom absl.testing import parameterized\nimport numpy\nimport six\nimport tensorflow as tf\nfrom tensorflow.core.example import example_pb2\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator import estimator_lib\nfrom tensorflow_estimator.python.estimator import extenders\nfrom tensorflow_estimator.python.estimator.canned.timeseries import ar_model\nfrom tensorflow_estimator.python.estimator.canned.timeseries import estimators as ts_estimators\nfrom tensorflow_estimator.python.estimator.canned.timeseries import feature_keys\nfrom tensorflow_estimator.python.estimator.canned.timeseries import head as ts_head_lib\nfrom tensorflow_estimator.python.estimator.canned.timeseries import model\nfrom tensorflow_estimator.python.estimator.canned.timeseries import state_management\n\n\nclass HeadTest(tf.test.TestCase):\n\n  def test_labels_provided_error(self):\n    model_fn = _stub_model_fn()\n    for mode in [\n        estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL,\n        
estimator_lib.ModeKeys.PREDICT\n    ]:\n      with self.assertRaisesRegexp(ValueError, \"received a `labels`\"):\n        model_fn(features={}, labels={\"a\": \"b\"}, mode=mode)\n\n      with self.assertRaisesRegexp(ValueError, \"received a `labels`\"):\n        model_fn(features={}, labels=tf.zeros([]), mode=mode)\n\n  def test_unknown_mode(self):\n    model_fn = _stub_model_fn()\n    with self.assertRaisesRegexp(ValueError, \"Unknown mode 'Not a mode'\"):\n      model_fn(features={}, labels={}, mode=\"Not a mode\")\n\n\nclass _TickerModel(object):\n  num_features = 1\n  dtype = tf.dtypes.float32\n\n  def initialize_graph(self, input_statistics):\n    pass\n\n  def define_loss(self, features, mode):\n    del mode  # unused\n    return model.ModelOutputs(\n        loss=features[\"ticker\"],\n        end_state=(features[\"ticker\"], features[\"ticker\"]),\n        prediction_times=tf.zeros(()),\n        predictions={\"ticker\": features[\"ticker\"]})\n\n\n@test_util.run_v1_only(\"Currently incompatible with ResourceVariable\")\nclass EvaluationMetricsTests(tf.test.TestCase):\n\n  def test_metrics_consistent(self):\n    # Tests that the identity metrics used to report in-sample predictions match\n    # the behavior of standard metrics.\n    g = tf.Graph()\n    with g.as_default():\n      features = {\n          feature_keys.TrainEvalFeatures.TIMES:\n              tf.zeros((1, 1)),\n          feature_keys.TrainEvalFeatures.VALUES:\n              tf.zeros((1, 1, 1)),\n          \"ticker\":\n              tf.reshape(\n                  tf.cast(\n                      tf.compat.v1.Variable(\n                          name=\"ticker\",\n                          initial_value=0,\n                          dtype=tf.dtypes.int64,\n                          collections=[tf.compat.v1.GraphKeys.LOCAL_VARIABLES])\n                      .count_up_to(10),\n                      dtype=tf.dtypes.float32), (1, 1, 1))\n      }\n      model_fn = ts_head_lib.TimeSeriesRegressionHead(\n  
        model=_TickerModel(),\n          state_manager=state_management.PassthroughStateManager(),\n          optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n              0.001)).create_estimator_spec\n      outputs = model_fn(\n          features=features, labels=None, mode=estimator_lib.ModeKeys.EVAL)\n      metric_update_ops = [\n          metric[1] for metric in outputs.eval_metric_ops.values()\n      ]\n      loss_mean, loss_update = tf.compat.v1.metrics.mean(outputs.loss)\n      metric_update_ops.append(loss_update)\n      with self.cached_session() as sess:\n        coordinator = tf.train.Coordinator()\n        tf.compat.v1.train.queue_runner.start_queue_runners(\n            sess, coord=coordinator)\n        tf.compat.v1.initializers.local_variables().run()\n        sess.run(metric_update_ops)\n        loss_evaled, metric_evaled, nested_metric_evaled = sess.run(\n            (loss_mean, outputs.eval_metric_ops[\"ticker\"][0],\n             outputs.eval_metric_ops[\n                 feature_keys.FilteringResults.STATE_TUPLE][0][0]))\n        # The custom model_utils metrics for in-sample predictions should be in\n        # sync with the Estimator's mean metric for model loss.\n        self.assertAllClose(0., loss_evaled)\n        self.assertAllClose((((0.,),),), metric_evaled)\n        self.assertAllClose((((0.,),),), nested_metric_evaled)\n        coordinator.request_stop()\n        coordinator.join()\n\n  def test_custom_metrics(self):\n    \"\"\"Tests that the custom metrics can be applied to the estimator.\"\"\"\n    model_dir = self.get_temp_dir()\n    estimator = ts_estimators.LSTMAutoRegressor(\n        periodicities=1,\n        input_window_size=1,\n        output_window_size=1,\n        num_features=1,\n        num_units=4,\n        optimizer=tf.compat.v1.train.AdamOptimizer(0.001),\n        config=estimator_lib.RunConfig(tf_random_seed=4),\n        model_dir=model_dir)\n\n    def input_fn():\n      return {\n          
feature_keys.TrainEvalFeatures.TIMES: [[1, 2, 3], [7, 8, 9]],\n          feature_keys.TrainEvalFeatures.VALUES:\n              numpy.array([[[0.], [1.], [0.]], [[2.], [3.], [2.]]])\n      }\n\n    def metrics_fn(predictions, features):\n      # checking that the inputs are properly passed.\n      predict = predictions[\"mean\"]\n      target = features[feature_keys.TrainEvalFeatures.VALUES][:, -1, 0]\n      return {\n          \"plain_boring_metric386\":\n              (tf.math.reduce_mean(tf.math.abs(predict - target)), tf.no_op()),\n          \"fun_metric101\": (tf.math.reduce_sum(predict + target), tf.no_op()),\n      }\n\n    # Evaluation without training is enough for testing custom metrics.\n    estimator = extenders.add_metrics(estimator, metrics_fn)\n    evaluation = estimator.evaluate(input_fn, steps=1)\n    self.assertIn(\"plain_boring_metric386\", evaluation)\n    self.assertIn(\"fun_metric101\", evaluation)\n    self.assertIn(\"average_loss\", evaluation)\n\n\nclass _StubModel(object):\n  num_features = 3\n  dtype = tf.dtypes.float64\n\n  def initialize_graph(self, input_statistics):\n    del input_statistics  # unused\n\n\ndef _stub_model_fn():\n  return ts_head_lib.TimeSeriesRegressionHead(\n      model=_StubModel(),\n      state_manager=state_management.PassthroughStateManager(),\n      optimizer=tf.compat.v1.train.AdamOptimizer(0.001)).create_estimator_spec\n\n\nclass TrainEvalFeatureCheckingTests(tf.test.TestCase):\n\n  def test_no_time_feature(self):\n    model_fn = _stub_model_fn()\n    for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:\n      with self.assertRaisesRegexp(\n          ValueError, \"Expected a '{}' feature\".format(\n              feature_keys.TrainEvalFeatures.TIMES)):\n        model_fn(\n            features={feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]},\n            labels=None,\n            mode=mode)\n\n  def test_no_value_feature(self):\n    model_fn = _stub_model_fn()\n    for mode in 
[estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:\n      with self.assertRaisesRegexp(\n          ValueError, \"Expected a '{}' feature\".format(\n              feature_keys.TrainEvalFeatures.VALUES)):\n        model_fn(\n            features={feature_keys.TrainEvalFeatures.TIMES: [[1]]},\n            labels=None,\n            mode=mode)\n\n  def test_bad_time_rank(self):\n    model_fn = _stub_model_fn()\n    for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:\n      with self.assertRaisesRegexp(\n          ValueError, \"Expected shape.*for feature '{}'\".format(\n              feature_keys.TrainEvalFeatures.TIMES)):\n        model_fn(\n            features={\n                feature_keys.TrainEvalFeatures.TIMES: [[[1]]],\n                feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]\n            },\n            labels=None,\n            mode=mode)\n\n  def test_bad_value_rank(self):\n    model_fn = _stub_model_fn()\n    for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:\n      with self.assertRaisesRegexp(\n          ValueError, \"Expected shape.*for feature '{}'\".format(\n              feature_keys.TrainEvalFeatures.VALUES)):\n        model_fn(\n            features={\n                feature_keys.TrainEvalFeatures.TIMES: [[1]],\n                feature_keys.TrainEvalFeatures.VALUES: [[1.]]\n            },\n            labels=None,\n            mode=mode)\n\n  def test_bad_value_num_features(self):\n    model_fn = _stub_model_fn()\n    for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:\n      with self.assertRaisesRegexp(\n          ValueError, \"Expected shape.*, 3.*for feature '{}'\".format(\n              feature_keys.TrainEvalFeatures.VALUES)):\n        model_fn(\n            features={\n                feature_keys.TrainEvalFeatures.TIMES: [[1]],\n                feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]\n            },\n            labels=None,\n            mode=mode)\n\n  
def test_bad_exogenous_shape(self):\n    model_fn = _stub_model_fn()\n    for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]:\n      with self.assertRaisesRegexp(\n          ValueError, \"Features must have shape.*for feature 'exogenous'\"):\n        model_fn(\n            features={\n                feature_keys.TrainEvalFeatures.TIMES: [[1]],\n                feature_keys.TrainEvalFeatures.VALUES: [[[1., 2., 3.]]],\n                \"exogenous\": [[1], [2]]\n            },\n            labels=None,\n            mode=mode)\n\n\nclass PredictFeatureCheckingTests(tf.test.TestCase):\n\n  def test_no_time_feature(self):\n    model_fn = _stub_model_fn()\n    with self.assertRaisesRegexp(\n        ValueError, \"Expected a '{}' feature\".format(\n            feature_keys.PredictionFeatures.TIMES)):\n      model_fn(\n          features={\n              feature_keys.PredictionFeatures.STATE_TUPLE: ([[[1.]]], 1.)\n          },\n          labels=None,\n          mode=estimator_lib.ModeKeys.PREDICT)\n\n  def test_no_start_state_feature(self):\n    model_fn = _stub_model_fn()\n    with self.assertRaisesRegexp(\n        ValueError, \"Expected a '{}' feature\".format(\n            feature_keys.PredictionFeatures.STATE_TUPLE)):\n      model_fn(\n          features={feature_keys.PredictionFeatures.TIMES: [[1]]},\n          labels=None,\n          mode=estimator_lib.ModeKeys.PREDICT)\n\n  def test_bad_time_rank(self):\n    model_fn = _stub_model_fn()\n    with self.assertRaisesRegexp(\n        ValueError, \"Expected shape.*for feature '{}'\".format(\n            feature_keys.PredictionFeatures.TIMES)):\n      model_fn(\n          features={\n              feature_keys.PredictionFeatures.TIMES: 1,\n              feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.))\n          },\n          labels=None,\n          mode=estimator_lib.ModeKeys.PREDICT)\n\n  def test_bad_exogenous_shape(self):\n    model_fn = _stub_model_fn()\n    with self.assertRaisesRegexp(\n 
       ValueError, \"Features must have shape.*for feature 'exogenous'\"):\n      model_fn(\n          features={\n              feature_keys.PredictionFeatures.TIMES: [[1]],\n              feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.)),\n              \"exogenous\": 1.\n          },\n          labels=None,\n          mode=estimator_lib.ModeKeys.PREDICT)\n\n\n@test_util.run_v1_only(\"Currently incompatible with ResourceVariable\")\nclass OneShotTests(parameterized.TestCase):\n\n  def test_one_shot_prediction_head_export(self):\n\n    def _new_temp_dir():\n      return os.path.join(tf.compat.v1.test.get_temp_dir(), str(ops.uid()))\n\n    model_dir = _new_temp_dir()\n    categorical_column = tf.feature_column.categorical_column_with_hash_bucket(\n        key=\"categorical_exogenous_feature\", hash_bucket_size=16)\n    exogenous_feature_columns = [\n        tf.feature_column.numeric_column(\"2d_exogenous_feature\", shape=(2,)),\n        tf.feature_column.embedding_column(\n            categorical_column=categorical_column, dimension=10)\n    ]\n    estimator = ts_estimators.TimeSeriesRegressor(\n        model=ar_model.ARModel(\n            periodicities=10,\n            input_window_size=10,\n            output_window_size=6,\n            num_features=5,\n            exogenous_feature_columns=exogenous_feature_columns,\n            prediction_model_factory=functools.partial(\n                ar_model.LSTMPredictionModel, num_units=10)),\n        head_type=ts_head_lib.OneShotPredictionHead,\n        model_dir=model_dir)\n\n    def train_input_fn():\n      num_range = tf.range(16, dtype=tf.dtypes.int64)\n      features = {\n          feature_keys.TrainEvalFeatures.TIMES:\n              tf.compat.v1.expand_dims(num_range, axis=0),\n          feature_keys.TrainEvalFeatures.VALUES:\n              tf.compat.v1.expand_dims(\n                  tf.tile(num_range[:, None], [1, 5]), axis=0),\n          \"2d_exogenous_feature\":\n              tf.ones([1, 16, 2]),\n    
      \"categorical_exogenous_feature\":\n              tf.compat.v1.expand_dims(\n                  tf.tile([\"strkey\"], [16])[:, None], axis=0)\n      }\n      return features\n\n    estimator.train(input_fn=train_input_fn, steps=5)\n    result = estimator.evaluate(input_fn=train_input_fn, steps=1)\n    self.assertIn(\"average_loss\", result)\n    self.assertNotIn(feature_keys.State.STATE_TUPLE, result)\n    input_receiver_fn = estimator.build_raw_serving_input_receiver_fn()\n    export_location = estimator.export_saved_model(_new_temp_dir(),\n                                                   input_receiver_fn)\n    graph = tf.Graph()\n    with graph.as_default():\n      with tf.compat.v1.Session() as session:\n        signatures = tf.compat.v1.saved_model.load(session,\n                                                   [tf.saved_model.SERVING],\n                                                   export_location)\n        self.assertEqual([feature_keys.SavedModelLabels.PREDICT],\n                         list(signatures.signature_def.keys()))\n        predict_signature = signatures.signature_def[\n            feature_keys.SavedModelLabels.PREDICT]\n        six.assertCountEqual(self, [\n            feature_keys.FilteringFeatures.TIMES,\n            feature_keys.FilteringFeatures.VALUES, \"2d_exogenous_feature\",\n            \"categorical_exogenous_feature\"\n        ], predict_signature.inputs.keys())\n        features = {\n            feature_keys.TrainEvalFeatures.TIMES:\n                numpy.tile(\n                    numpy.arange(35, dtype=numpy.int64)[None, :], [2, 1]),\n            feature_keys.TrainEvalFeatures.VALUES:\n                numpy.tile(\n                    numpy.arange(20, dtype=numpy.float32)[None, :, None],\n                    [2, 1, 5]),\n            \"2d_exogenous_feature\":\n                numpy.ones([2, 35, 2]),\n            \"categorical_exogenous_feature\":\n                numpy.tile(\n                    numpy.array([\"strkey\"] 
* 35)[None, :, None], [2, 1, 1])\n        }\n        feeds = {\n            graph.as_graph_element(input_value.name): features[input_key]\n            for input_key, input_value in predict_signature.inputs.items()\n        }\n        fetches = {\n            output_key: graph.as_graph_element(output_value.name)\n            for output_key, output_value in predict_signature.outputs.items()\n        }\n        output = session.run(fetches, feed_dict=feeds)\n        self.assertEqual((2, 15, 5), output[\"mean\"].shape)\n    # Build a parsing input function, then make a tf.Example for it to parse.\n    export_location = estimator.export_saved_model(\n        _new_temp_dir(),\n        estimator.build_one_shot_parsing_serving_input_receiver_fn(\n            filtering_length=20, prediction_length=15))\n    graph = tf.Graph()\n    with graph.as_default():\n      with tf.compat.v1.Session() as session:\n        example = example_pb2.Example()\n        times = example.features.feature[feature_keys.TrainEvalFeatures.TIMES]\n        values = example.features.feature[feature_keys.TrainEvalFeatures.VALUES]\n        times.int64_list.value.extend(range(35))\n        for i in range(20):\n          values.float_list.value.extend(\n              [float(i) * 2. 
+ feature_number for feature_number in range(5)])\n        real_feature = example.features.feature[\"2d_exogenous_feature\"]\n        categorical_feature = example.features.feature[\n            \"categorical_exogenous_feature\"]\n        for i in range(35):\n          real_feature.float_list.value.extend([1, 1])\n          categorical_feature.bytes_list.value.append(b\"strkey\")\n        # Serialize the tf.Example for feeding to the Session\n        examples = [example.SerializeToString()] * 2\n        signatures = tf.compat.v1.saved_model.load(session,\n                                                   [tf.saved_model.SERVING],\n                                                   export_location)\n        predict_signature = signatures.signature_def[\n            feature_keys.SavedModelLabels.PREDICT]\n        ((_, input_value),) = predict_signature.inputs.items()\n        feeds = {graph.as_graph_element(input_value.name): examples}\n        fetches = {\n            output_key: graph.as_graph_element(output_value.name)\n            for output_key, output_value in predict_signature.outputs.items()\n        }\n        output = session.run(fetches, feed_dict=feeds)\n        self.assertEqual((2, 15, 5), output[\"mean\"].shape)\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/timeseries/math_utils.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Miscellaneous utilities used by time series models.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport tensorflow as tf\nfrom tensorflow.python.ops import gen_math_ops\nfrom tensorflow_estimator.python.estimator.canned.timeseries.feature_keys import TrainEvalFeatures\n\n\ndef replicate_state(start_state, batch_size):\n  \"\"\"Create batch versions of state.\n\n  Takes a list of Tensors, adds a batch dimension, and replicates\n  batch_size times across that batch dimension. 
Used to replicate the\n  non-batch state returned by get_start_state in define_loss.\n\n  Args:\n    start_state: Model-defined state to replicate.\n    batch_size: Batch dimension for data.\n\n  Returns:\n    Replicated versions of the state.\n  \"\"\"\n  flattened_state = tf.nest.flatten(start_state)\n  replicated_state = [\n      tf.tile(\n          tf.compat.v1.expand_dims(state_nonbatch, 0),\n          tf.concat([[batch_size],\n                     tf.ones([tf.rank(state_nonbatch)], dtype=tf.dtypes.int32)],\n                    0)) for state_nonbatch in flattened_state\n  ]\n  return tf.nest.pack_sequence_as(start_state, replicated_state)\n\n\nMoments = collections.namedtuple(\"Moments\", [\"mean\", \"variance\"])\n\n# Currently all of these statistics are computed incrementally (i.e. are updated\n# every time a new mini-batch of training data is presented) when this object is\n# created in InputStatisticsFromMiniBatch.\nInputStatistics = collections.namedtuple(\n    \"InputStatistics\",\n    [\n        # The mean and variance of each feature in a chunk (with a size\n        # configured in the statistics object) at the start of the series. A\n        # tuple of (mean, variance), each with shape [number of features],\n        # floating point. One use is in state space models, to keep priors\n        # calibrated even as earlier parts of the series are presented. If this\n        # object was created by InputStatisticsFromMiniBatch, these moments are\n        # computed based on the earliest chunk of data presented so far.\n        # However, there is a race condition in the update, so these may reflect\n        # statistics later in the series, but should eventually reflect\n        # statistics in a chunk at the series start.\n        \"series_start_moments\",\n        # The mean and variance of each feature over the entire series. A tuple\n        # of (mean, variance), each with shape [number of features]. 
If this\n        # object was created by InputStatisticsFromMiniBatch, these moments are\n        # estimates based on the data seen so far.\n        \"overall_feature_moments\",\n        # The first (lowest) time in the series, a scalar integer. If this\n        # object was created by InputStatisticsFromMiniBatch, this is the lowest\n        # time seen so far rather than the lowest time that will ever be seen\n        # (guaranteed to be at least as low as the lowest time presented in the\n        # current minibatch).\n        \"start_time\",\n        # Count of data points, a scalar integer. If this object was created by\n        # InputStatisticsFromMiniBatch, this is an estimate of the total number\n        # of observations in the whole dataset computed based on the density of\n        # the series and the minimum and maximum times seen.\n        \"total_observation_count\",\n    ])\n\n\n# TODO(allenl): It would be nice to do something with full series statistics\n# when the user provides that.\nclass InputStatisticsFromMiniBatch(object):\n  \"\"\"Generate statistics from mini-batch input.\"\"\"\n\n  def __init__(self, num_features, dtype, starting_variance_window_size=16):\n    \"\"\"Configure the input statistics object.\n\n    Args:\n      num_features: Number of features for the time series\n      dtype: The floating point data type to use.\n      starting_variance_window_size: The number of datapoints to use when\n        computing the mean and variance at the start of the series.\n    \"\"\"\n    self._starting_variance_window_size = starting_variance_window_size\n    self._num_features = num_features\n    self._dtype = dtype\n\n  def initialize_graph(self, features, update_statistics=True):\n    \"\"\"Create any ops needed to provide input statistics.\n\n    Should be called before statistics are requested.\n\n    Args:\n      features: A dictionary, the output of a `TimeSeriesInputFn` (with keys\n        TrainEvalFeatures.TIMES and 
TrainEvalFeatures.VALUES).\n      update_statistics: Whether `features` should be used to update adaptive\n        statistics. Typically True for training and false for evaluation.\n\n    Returns:\n      An InputStatistics object composed of Variables, which will be updated\n      based on mini-batches of data if requested.\n    \"\"\"\n    if (TrainEvalFeatures.TIMES in features and\n        TrainEvalFeatures.VALUES in features):\n      times = features[TrainEvalFeatures.TIMES]\n      values = features[TrainEvalFeatures.VALUES]\n    else:\n      # times and values may not be available, for example during prediction. We\n      # still need to retrieve our variables so that they can be read from, even\n      # if we're not going to update them.\n      times = None\n      values = None\n    # Create/retrieve variables representing input statistics, initialized\n    # without data to avoid deadlocking if variables are initialized before\n    # queue runners are started.\n    with tf.compat.v1.variable_scope(\"input_statistics\", use_resource=True):\n      statistics = self._create_variable_statistics_object()\n    with tf.compat.v1.variable_scope(\n        \"input_statistics_auxiliary\", use_resource=True):\n      # Secondary statistics, necessary for the incremental computation of the\n      # primary statistics (e.g. 
counts and sums for computing a mean\n      # incrementally).\n      auxiliary_variables = self._AdaptiveInputAuxiliaryStatistics(\n          num_features=self._num_features, dtype=self._dtype)\n    if update_statistics and times is not None and values is not None:\n      # If we have times and values from mini-batch input, create update ops to\n      # take the new data into account.\n      assign_op = self._update_statistics_from_mini_batch(\n          statistics, auxiliary_variables, times, values)\n      with tf.control_dependencies([assign_op]):\n        stat_variables = tf.nest.pack_sequence_as(\n            statistics,\n            [tf.identity(tensor) for tensor in tf.nest.flatten(statistics)])\n        # Since start time updates have a race condition, ensure that the\n        # reported start time is at least as low as the lowest time in this\n        # mini-batch. The start time should converge on the correct value\n        # eventually even with the race condition, but for example state space\n        # models have an assertion which could fail without this\n        # post-processing.\n        min_time = tf.cast(tf.math.reduce_min(times), tf.dtypes.int64)\n        start_time = tf.math.minimum(stat_variables.start_time, min_time)\n        return stat_variables._replace(start_time=start_time)\n    else:\n      return statistics\n\n  class _AdaptiveInputAuxiliaryStatistics(\n      collections.namedtuple(\n          \"_AdaptiveInputAuxiliaryStatistics\",\n          [\n              # The maximum time seen (best effort if updated from multiple\n              # workers; see notes about race condition below).\n              \"max_time_seen\",\n              # The number of chunks seen.\n              \"chunk_count\",\n              # The sum across chunks of their \"time density\" (number of times\n              # per example).\n              \"inter_observation_duration_sum\",\n              # The number of examples seen (each example has a single time\n       
       # associated with it and one or more real-valued features).\n              \"example_count\",\n              # The sum of values for each feature. Shape [number of features].\n              \"overall_feature_sum\",\n              # The sum of squared values for each feature.\n              # Shape [number of features].\n              \"overall_feature_sum_of_squares\",\n          ])):\n    \"\"\"Extra statistics used to incrementally update InputStatistics.\"\"\"\n\n    def __new__(cls, num_features, dtype):\n      return super(\n          InputStatisticsFromMiniBatch  # pylint: disable=protected-access\n          ._AdaptiveInputAuxiliaryStatistics,\n          cls).__new__(\n              cls,\n              max_time_seen=tf.compat.v1.get_variable(\n                  name=\"max_time_seen\",\n                  initializer=tf.dtypes.int64.min,\n                  dtype=tf.dtypes.int64,\n                  trainable=False),\n              chunk_count=tf.compat.v1.get_variable(\n                  name=\"chunk_count\",\n                  initializer=tf.compat.v1.initializers.zeros(),\n                  shape=[],\n                  dtype=tf.dtypes.int64,\n                  trainable=False),\n              inter_observation_duration_sum=tf.compat.v1.get_variable(\n                  name=\"inter_observation_duration_sum\",\n                  initializer=tf.compat.v1.initializers.zeros(),\n                  shape=[],\n                  dtype=dtype,\n                  trainable=False),\n              example_count=tf.compat.v1.get_variable(\n                  name=\"example_count\",\n                  shape=[],\n                  dtype=tf.dtypes.int64,\n                  trainable=False),\n              overall_feature_sum=tf.compat.v1.get_variable(\n                  name=\"overall_feature_sum\",\n                  shape=[num_features],\n                  dtype=dtype,\n                  initializer=tf.compat.v1.initializers.zeros(),\n                  
trainable=False),\n              overall_feature_sum_of_squares=tf.compat.v1.get_variable(\n                  name=\"overall_feature_sum_of_squares\",\n                  shape=[num_features],\n                  dtype=dtype,\n                  initializer=tf.compat.v1.initializers.zeros(),\n                  trainable=False))\n\n  def _update_statistics_from_mini_batch(self, statistics, auxiliary_variables,\n                                         times, values):\n    \"\"\"Given mini-batch input, update `statistics` and `auxiliary_variables`.\"\"\"\n    values = tf.cast(values, self._dtype)\n    # The density (measured in times per observation) that we see in each part\n    # of the mini-batch.\n    batch_inter_observation_duration = (\n        tf.cast(\n            tf.math.reduce_max(times, axis=1) -\n            tf.math.reduce_min(times, axis=1), self._dtype) /\n        tf.cast(tf.compat.v1.shape(times)[1] - 1, self._dtype))\n    # Co-locate updates with their variables to minimize race conditions when\n    # updating statistics.\n    with tf.compat.v1.device(auxiliary_variables.max_time_seen.device):\n      # There is a race condition if this value is being updated from multiple\n      # workers. 
However, it should eventually reach the correct value if the\n      # last chunk is presented enough times.\n      latest_time = tf.cast(tf.math.reduce_max(times), tf.dtypes.int64)\n      max_time_seen = tf.math.maximum(auxiliary_variables.max_time_seen,\n                                      latest_time)\n      max_time_seen_assign = tf.compat.v1.assign(\n          auxiliary_variables.max_time_seen, max_time_seen)\n    with tf.compat.v1.device(auxiliary_variables.chunk_count.device):\n      chunk_count_assign = tf.compat.v1.assign_add(\n          auxiliary_variables.chunk_count,\n          tf.compat.v1.shape(times, out_type=tf.dtypes.int64)[0])\n    with tf.compat.v1.device(\n        auxiliary_variables.inter_observation_duration_sum.device):\n      inter_observation_duration_assign = tf.compat.v1.assign_add(\n          auxiliary_variables.inter_observation_duration_sum,\n          tf.math.reduce_sum(batch_inter_observation_duration))\n    with tf.compat.v1.device(auxiliary_variables.example_count.device):\n      example_count_assign = tf.compat.v1.assign_add(\n          auxiliary_variables.example_count,\n          tf.compat.v1.size(times, out_type=tf.dtypes.int64))\n    # Note: These mean/variance updates assume that all points are equally\n    # likely, which is not true if _chunks_ are sampled uniformly from the space\n    # of all possible contiguous chunks, since points at the start and end of\n    # the series are then members of fewer chunks. 
For series which are much\n    # longer than the chunk size (the usual/expected case), this effect becomes\n    # irrelevant.\n    with tf.compat.v1.device(auxiliary_variables.overall_feature_sum.device):\n      overall_feature_sum_assign = tf.compat.v1.assign_add(\n          auxiliary_variables.overall_feature_sum,\n          tf.math.reduce_sum(values, axis=[0, 1]))\n    with tf.compat.v1.device(\n        auxiliary_variables.overall_feature_sum_of_squares.device):\n      overall_feature_sum_of_squares_assign = tf.compat.v1.assign_add(\n          auxiliary_variables.overall_feature_sum_of_squares,\n          tf.math.reduce_sum(values**2, axis=[0, 1]))\n    per_chunk_aux_updates = tf.group(max_time_seen_assign, chunk_count_assign,\n                                     inter_observation_duration_assign,\n                                     example_count_assign,\n                                     overall_feature_sum_assign,\n                                     overall_feature_sum_of_squares_assign)\n    with tf.control_dependencies([per_chunk_aux_updates]):\n      example_count_float = tf.cast(auxiliary_variables.example_count,\n                                    self._dtype)\n      new_feature_mean = (\n          auxiliary_variables.overall_feature_sum / example_count_float)\n      overall_feature_mean_update = tf.compat.v1.assign(\n          statistics.overall_feature_moments.mean, new_feature_mean)\n      overall_feature_var_update = tf.compat.v1.assign(\n          statistics.overall_feature_moments.variance,\n          # De-biased n / (n - 1) variance correction\n          example_count_float / (example_count_float - 1.) 
*\n          (auxiliary_variables.overall_feature_sum_of_squares /\n           example_count_float - new_feature_mean**2))\n      # TODO(b/35675805): Remove this cast\n      min_time_batch = tf.cast(\n          tf.compat.v1.math.argmin(times[:, 0]), tf.dtypes.int32)\n\n      def series_start_updates():\n        # If this is the lowest-time chunk that we have seen so far, update\n        # series start moments to reflect that. Note that these statistics are\n        # \"best effort\", as there are race conditions in the update (however,\n        # they should eventually converge if the start of the series is\n        # presented enough times).\n        mean, variance = tf.compat.v1.nn.moments(\n            values[min_time_batch, :self._starting_variance_window_size],\n            axes=[0])\n        return tf.group(\n            tf.compat.v1.assign(statistics.series_start_moments.mean, mean),\n            tf.compat.v1.assign(statistics.series_start_moments.variance,\n                                variance))\n\n      with tf.compat.v1.device(statistics.start_time.device):\n        series_start_update = tf.compat.v1.cond(\n            # Update moments whenever we even match the lowest time seen so far,\n            # to ensure that series start statistics are eventually updated to\n            # their correct values, despite race conditions (i.e. eventually\n            # statistics.start_time will reflect the global lowest time, and\n            # given that we will eventually update the series start moments to\n            # their correct values).\n            tf.math.less_equal(times[min_time_batch, 0],\n                               tf.cast(statistics.start_time, times.dtype)),\n            series_start_updates,\n            tf.no_op)\n        with tf.control_dependencies([series_start_update]):\n          # There is a race condition if this update is performed in parallel on\n          # multiple workers. 
Since models may be sensitive to being presented\n          # with times before the putative start time, the value of this\n          # variable is post-processed above to guarantee that each worker is\n          # presented with a start time which is at least as low as the lowest\n          # time in its current mini-batch.\n          min_time = tf.cast(tf.math.reduce_min(times), tf.dtypes.int64)\n          start_time = tf.math.minimum(statistics.start_time, min_time)\n          start_time_update = tf.compat.v1.assign(statistics.start_time,\n                                                  start_time)\n      inter_observation_duration_estimate = (\n          auxiliary_variables.inter_observation_duration_sum /\n          tf.cast(auxiliary_variables.chunk_count, self._dtype))\n      # Estimate the total number of observations as:\n      #   (end time - start time + 1) * average intra-chunk time density\n      total_observation_count_update = tf.compat.v1.assign(\n          statistics.total_observation_count,\n          tf.cast(\n              gen_math_ops.round(\n                  tf.cast(max_time_seen_assign - start_time_update + 1,\n                          self._dtype) / inter_observation_duration_estimate),\n              tf.dtypes.int64))\n      per_chunk_stat_updates = tf.group(overall_feature_mean_update,\n                                        overall_feature_var_update,\n                                        series_start_update, start_time_update,\n                                        total_observation_count_update)\n    return per_chunk_stat_updates\n\n  def _create_variable_statistics_object(self):\n    \"\"\"Creates non-trainable variables representing input statistics.\"\"\"\n    series_start_moments = Moments(\n        mean=tf.compat.v1.get_variable(\n            name=\"series_start_mean\",\n            shape=[self._num_features],\n            dtype=self._dtype,\n            initializer=tf.compat.v1.initializers.zeros(),\n            
trainable=False),\n        variance=tf.compat.v1.get_variable(\n            name=\"series_start_variance\",\n            shape=[self._num_features],\n            dtype=self._dtype,\n            initializer=tf.compat.v1.initializers.ones(),\n            trainable=False))\n    overall_feature_moments = Moments(\n        mean=tf.compat.v1.get_variable(\n            name=\"overall_feature_mean\",\n            shape=[self._num_features],\n            dtype=self._dtype,\n            initializer=tf.compat.v1.initializers.zeros(),\n            trainable=False),\n        variance=tf.compat.v1.get_variable(\n            name=\"overall_feature_var\",\n            shape=[self._num_features],\n            dtype=self._dtype,\n            initializer=tf.compat.v1.initializers.ones(),\n            trainable=False))\n    start_time = tf.compat.v1.get_variable(\n        name=\"start_time\",\n        dtype=tf.dtypes.int64,\n        initializer=tf.dtypes.int64.max,\n        trainable=False)\n    total_observation_count = tf.compat.v1.get_variable(\n        name=\"total_observation_count\",\n        shape=[],\n        dtype=tf.dtypes.int64,\n        initializer=tf.compat.v1.initializers.ones(),\n        trainable=False)\n    return InputStatistics(\n        series_start_moments=series_start_moments,\n        overall_feature_moments=overall_feature_moments,\n        start_time=start_time,\n        total_observation_count=total_observation_count)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/timeseries/math_utils_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for math_utils.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.canned.timeseries import math_utils\nfrom tensorflow_estimator.python.estimator.canned.timeseries.feature_keys import TrainEvalFeatures\n\n\nclass InputStatisticsTests(tf.test.TestCase):\n\n  def _input_statistics_test_template(self,\n                                      stat_object,\n                                      num_features,\n                                      dtype,\n                                      warmup_iterations=0,\n                                      rtol=1e-6,\n                                      data_length=4):\n    graph = tf.Graph()\n    with graph.as_default():\n      data_length_range = tf.range(data_length, dtype=dtype)\n      num_features_range = tf.range(num_features, dtype=dtype)\n      times = 2 * data_length_range[None, :] - 3\n      values = (data_length_range[:, None] + num_features_range[None, :])[None,\n                                                                          ...]\n      features = {\n          TrainEvalFeatures.TIMES: times,\n          TrainEvalFeatures.VALUES: values,\n      
}\n      statistics = stat_object.initialize_graph(features=features)\n      with self.session(graph=graph) as session:\n        tf.compat.v1.initializers.global_variables().run()\n        coordinator = tf.train.Coordinator()\n        tf.compat.v1.train.queue_runner.start_queue_runners(\n            session, coord=coordinator)\n        for _ in range(warmup_iterations):\n          # A control dependency should ensure that, for queue-based statistics,\n          # a use of any statistic is preceded by an update of all adaptive\n          # statistics.\n          self.evaluate(statistics.total_observation_count)\n        self.assertAllClose(\n            tf.range(num_features, dtype=dtype) +\n            tf.math.reduce_mean(data_length_range)[None],\n            self.evaluate(statistics.series_start_moments.mean),\n            rtol=rtol)\n        self.assertAllClose(\n            tf.tile(\n                tf.math.reduce_variance(data_length_range)[None],\n                [num_features]),\n            self.evaluate(statistics.series_start_moments.variance),\n            rtol=rtol)\n        self.assertAllClose(\n            tf.math.reduce_mean(values[0], axis=0),\n            self.evaluate(statistics.overall_feature_moments.mean),\n            rtol=rtol)\n        self.assertAllClose(\n            tf.math.reduce_variance(values[0], axis=0),\n            self.evaluate(statistics.overall_feature_moments.variance),\n            rtol=rtol)\n        self.assertAllClose(-3, self.evaluate(statistics.start_time), rtol=rtol)\n        self.assertAllClose(\n            data_length,\n            self.evaluate(statistics.total_observation_count),\n            rtol=rtol)\n        coordinator.request_stop()\n        coordinator.join()\n\n  def test_queue(self):\n    for dtype in [tf.dtypes.float32, tf.dtypes.float64]:\n      for num_features in [1, 2, 3]:\n        self._input_statistics_test_template(\n            math_utils.InputStatisticsFromMiniBatch(\n                
num_features=num_features, dtype=dtype),\n            num_features=num_features,\n            dtype=dtype,\n            warmup_iterations=1000,\n            rtol=0.1)\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/timeseries/model.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Base class for time series models.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport abc\nimport collections\n\nimport six\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.canned.timeseries import math_utils\nfrom tensorflow_estimator.python.estimator.canned.timeseries.feature_keys import TrainEvalFeatures\n\nModelOutputs = collections.namedtuple(  # pylint: disable=invalid-name\n    typename=\"ModelOutputs\",\n    field_names=[\n        \"loss\",  # The scalar value to be minimized during training.\n        \"end_state\",  # A nested tuple specifying the model's state after\n                      # running on the specified data\n        \"predictions\",  # A dictionary of predictions, each with shape prefixed\n                        # by the shape of `prediction_times`.\n        \"prediction_times\"  # A [batch size x window size] integer Tensor\n                            # indicating times for which values in `predictions`\n                            # were computed.\n    ])\n\n\n@six.add_metaclass(abc.ABCMeta)\nclass TimeSeriesModel(object):\n  \"\"\"Base class for creating generative time series models.\"\"\"\n\n  def __init__(self,\n           
    num_features,\n               exogenous_feature_columns=None,\n               dtype=tf.dtypes.float32):\n    \"\"\"Constructor for generative models.\n\n    Args:\n      num_features: Number of features for the time series\n      exogenous_feature_columns: A list of `tf.feature_column`s (for example\n        `tf.feature_column.embedding_column`) corresponding to exogenous\n        features which provide extra information to the model but are not part\n        of the series to be predicted. Passed to\n        `tf.feature_column.input_layer`.\n      dtype: The floating point datatype to use.\n    \"\"\"\n    if exogenous_feature_columns:\n      self._exogenous_feature_columns = exogenous_feature_columns\n    else:\n      self._exogenous_feature_columns = []\n    self.num_features = num_features\n    self.dtype = dtype\n    self._input_statistics = None\n    self._graph_initialized = False\n    self._stats_means = None\n    self._stats_sigmas = None\n\n  @property\n  def exogenous_feature_columns(self):\n    \"\"\"`tf.feature_column`s for features which are not predicted.\"\"\"\n    return self._exogenous_feature_columns\n\n  # TODO(allenl): Move more of the generic machinery for generating and\n  # predicting into TimeSeriesModel, and possibly share it between generate()\n  # and predict()\n  def generate(self,\n               number_of_series,\n               series_length,\n               model_parameters=None,\n               seed=None):\n    \"\"\"Sample synthetic data from model parameters, with optional substitutions.\n\n    Returns `number_of_series` possible sequences of future values, sampled from\n    the generative model with each conditioned on the previous. 
Samples are\n    based on trained parameters, except for those parameters explicitly\n    overridden in `model_parameters`.\n\n    For distributions over future observations, see predict().\n\n    Args:\n      number_of_series: Number of time series to create.\n      series_length: Length of each time series.\n      model_parameters: A dictionary mapping model parameters to values, which\n        replace trained parameters when generating data.\n      seed: If specified, return deterministic time series according to this\n        value.\n\n    Returns:\n      A dictionary with keys TrainEvalFeatures.TIMES (mapping to an array with\n      shape [number_of_series, series_length]) and TrainEvalFeatures.VALUES\n      (mapping to an array with shape [number_of_series, series_length,\n      num_features]).\n    \"\"\"\n    raise NotImplementedError(\"This model does not support generation.\")\n\n  def initialize_graph(self, input_statistics=None):\n    \"\"\"Define ops for the model, not depending on any previously defined ops.\n\n    Args:\n      input_statistics: A math_utils.InputStatistics object containing input\n        statistics. 
If None, data-independent defaults are used, which may\n        result in longer or unstable training.\n    \"\"\"\n    self._graph_initialized = True\n    self._input_statistics = input_statistics\n    if self._input_statistics:\n      self._stats_means, variances = (\n          self._input_statistics.overall_feature_moments)\n      self._stats_sigmas = tf.math.sqrt(variances)\n\n  def _scale_data(self, data):\n    \"\"\"Scale data according to stats (input scale -> model scale).\"\"\"\n    if self._input_statistics is not None:\n      return (data - self._stats_means) / self._stats_sigmas\n    else:\n      return data\n\n  def _scale_variance(self, variance):\n    \"\"\"Scale variances according to stats (input scale -> model scale).\"\"\"\n    if self._input_statistics is not None:\n      return variance / self._input_statistics.overall_feature_moments.variance\n    else:\n      return variance\n\n  def _scale_back_data(self, data):\n    \"\"\"Scale back data according to stats (model scale -> input scale).\"\"\"\n    if self._input_statistics is not None:\n      return (data * self._stats_sigmas) + self._stats_means\n    else:\n      return data\n\n  def _scale_back_variance(self, variance):\n    \"\"\"Scale back variances according to stats (model scale -> input scale).\"\"\"\n    if self._input_statistics is not None:\n      return variance * self._input_statistics.overall_feature_moments.variance\n    else:\n      return variance\n\n  def _check_graph_initialized(self):\n    if not self._graph_initialized:\n      raise ValueError(\n          \"TimeSeriesModels require initialize_graph() to be called before \"\n          \"use. 
This defines variables and ops in the default graph, and \"\n          \"allows Tensor-valued input statistics to be specified.\")\n\n  def define_loss(self, features, mode):\n    \"\"\"Default loss definition with state replicated across a batch.\n\n    Time series passed to this model have a batch dimension, and each series in\n    a batch can be operated on in parallel. This loss definition assumes that\n    each element of the batch represents an independent sample conditioned on\n    the same initial state (i.e. it is simply replicated across the batch). A\n    batch size of one provides sequential operations on a single time series.\n\n    More complex processing may operate instead on get_start_state() and\n    get_batch_loss() directly.\n\n    Args:\n      features: A dictionary (such as is produced by a chunker) with at minimum\n        the following key/value pairs (others corresponding to the\n        `exogenous_feature_columns` argument to `__init__` may be included\n        representing exogenous regressors):\n        TrainEvalFeatures.TIMES: A [batch size x window size] integer Tensor\n          with times for each observation. If there is no artificial chunking,\n          the window size is simply the length of the time series.\n        TrainEvalFeatures.VALUES: A [batch size x window size x num features]\n          Tensor with values for each observation.\n      mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL). 
For INFER, see\n        predict().\n\n    Returns:\n      A ModelOutputs object.\n    \"\"\"\n    self._check_graph_initialized()\n    start_state = math_utils.replicate_state(\n        start_state=self.get_start_state(),\n        batch_size=tf.compat.v1.shape(features[TrainEvalFeatures.TIMES])[0])\n    return self.get_batch_loss(features=features, mode=mode, state=start_state)\n\n  # TODO(vitalyk,allenl): Better documentation surrounding options for chunking,\n  # references to papers, etc.\n  @abc.abstractmethod\n  def get_start_state(self):\n    \"\"\"Returns a tuple of state for the start of the time series.\n\n    For example, a mean and covariance. State should not have a batch\n    dimension, and will often be TensorFlow Variables to be learned along with\n    the rest of the model parameters.\n    \"\"\"\n    pass\n\n  @abc.abstractmethod\n  def get_batch_loss(self, features, mode, state):\n    \"\"\"Return predictions, losses, and end state for a time series.\n\n    Args:\n      features: A dictionary with times, values, and (optionally) exogenous\n        regressors. See `define_loss`.\n      mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL, INFER).\n      state: Model-dependent state, each with size [batch size x ...]. The\n        number and type will typically be fixed by the model (for example a mean\n        and variance).\n\n    Returns:\n      A ModelOutputs object.\n    \"\"\"\n    pass\n\n  @abc.abstractmethod\n  def predict(self, features):\n    \"\"\"Returns predictions of future observations given an initial state.\n\n    Computes distributions for future observations. 
For sampled draws from the\n    model where each is conditioned on the previous, see generate().\n\n    Args:\n      features: A dictionary with at minimum the following key/value pairs\n        (others corresponding to the `exogenous_feature_columns` argument to\n        `__init__` may be included representing exogenous regressors):\n        PredictionFeatures.TIMES: A [batch size x window size] Tensor with times\n          to make predictions for. Times must be increasing within each part of\n          the batch, and must be greater than the last time `state` was updated.\n        PredictionFeatures.STATE_TUPLE: Model-dependent state, each with size\n          [batch size x ...]. The number and type will typically be fixed by the\n          model (for example a mean and variance). Typically these will be the\n          end state returned by get_batch_loss, predicting beyond that data.\n\n    Returns:\n      A dictionary with model-dependent predictions corresponding to the\n      requested times. Keys indicate the type of prediction, and values have\n      shape [batch size x window size x ...]. For example state space models\n      return a \"predicted_mean\" and \"predicted_covariance\".\n    \"\"\"\n    pass\n\n  def _get_exogenous_embedding_shape(self):\n    \"\"\"Computes the shape of the vector returned by _process_exogenous_features.\n\n    Returns:\n      The shape as a list. 
Does not include a batch dimension.\n    \"\"\"\n    if not self._exogenous_feature_columns:\n      return (0,)\n    with tf.Graph().as_default():\n      parsed_features = (\n          tf.compat.v1.feature_column.make_parse_example_spec(\n              self._exogenous_feature_columns))\n      placeholder_features = tf.compat.v1.io.parse_example(\n          serialized=tf.compat.v1.placeholder(\n              shape=[None], dtype=tf.dtypes.string),\n          features=parsed_features)\n      embedded = tf.compat.v1.feature_column.input_layer(\n          features=placeholder_features,\n          feature_columns=self._exogenous_feature_columns)\n      return embedded.get_shape().as_list()[1:]\n\n  def _process_exogenous_features(self, times, features):\n    \"\"\"Create a single vector from exogenous features.\n\n    Args:\n      times: A [batch size, window size] vector of times for this batch,\n        primarily used to check the shape information of exogenous features.\n      features: A dictionary of exogenous features corresponding to the columns\n        in self._exogenous_feature_columns. Each value should have a shape\n        prefixed by [batch size, window size].\n\n    Returns:\n      A Tensor with shape [batch size, window size, exogenous dimension], where\n      the size of the exogenous dimension depends on the exogenous feature\n      columns passed to the model's constructor.\n    Raises:\n      ValueError: If an exogenous feature has an unknown rank.\n    \"\"\"\n    if self._exogenous_feature_columns:\n      exogenous_features_single_batch_dimension = {}\n      for name, tensor in features.items():\n        if tensor.get_shape().ndims is None:\n          # input_from_feature_columns does not support completely unknown\n          # feature shapes, so we save on a bit of logic and provide a better\n          # error message by checking that here.\n          raise ValueError(\n              (\"Features with unknown rank are not supported. 
Got shape {} for \"\n               \"feature {}.\").format(tensor.get_shape(), name))\n        tensor_shape_dynamic = tf.compat.v1.shape(tensor)\n        tensor = tf.reshape(\n            tensor,\n            tf.concat([[tensor_shape_dynamic[0] * tensor_shape_dynamic[1]],\n                       tensor_shape_dynamic[2:]],\n                      axis=0))\n        # Avoid shape warnings when embedding \"scalar\" exogenous features (those\n        # with only batch and window dimensions); input_from_feature_columns\n        # expects input ranks to match the embedded rank.\n        if tensor.get_shape().ndims == 1 and tensor.dtype != tf.dtypes.string:\n          exogenous_features_single_batch_dimension[name] = tensor[:, None]\n        else:\n          exogenous_features_single_batch_dimension[name] = tensor\n      embedded_exogenous_features_single_batch_dimension = (\n          tf.compat.v1.feature_column.input_layer(\n              features=exogenous_features_single_batch_dimension,\n              feature_columns=self._exogenous_feature_columns,\n              trainable=True))\n      exogenous_regressors = tf.reshape(\n          embedded_exogenous_features_single_batch_dimension,\n          tf.concat([\n              tf.compat.v1.shape(times),\n              tf.compat.v1.shape(\n                  embedded_exogenous_features_single_batch_dimension)[1:]\n          ],\n                    axis=0))\n      exogenous_regressors.set_shape(times.get_shape().concatenate(\n          embedded_exogenous_features_single_batch_dimension.get_shape()[1:]))\n      exogenous_regressors = tf.cast(exogenous_regressors, dtype=self.dtype)\n    else:\n      # Not having any exogenous features is a special case so that models can\n      # avoid superfluous updates, which may not be free of side effects due to\n      # bias terms in transformations.\n      exogenous_regressors = None\n    return exogenous_regressors\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/timeseries/model_utils.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Helper functions for training and constructing time series Models.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.canned.timeseries import feature_keys\n\n\n# TODO(agarwal): Remove and replace with functionality from tf.slim\ndef fully_connected(inp,\n                    inp_size,\n                    layer_size,\n                    name,\n                    activation=tf.nn.relu,\n                    dtype=tf.dtypes.float32):\n  \"\"\"Helper method to create a fully connected hidden layer.\"\"\"\n  wt = tf.compat.v1.get_variable(\n      name=\"{}_weight\".format(name), shape=[inp_size, layer_size], dtype=dtype)\n  bias = tf.compat.v1.get_variable(\n      name=\"{}_bias\".format(name),\n      shape=[layer_size],\n      initializer=tf.compat.v1.initializers.zeros())\n  output = tf.compat.v1.nn.xw_plus_b(inp, wt, bias)\n  if activation is not None:\n    assert callable(activation)\n    output = activation(output)\n  return output\n\n\ndef canonicalize_times_or_steps_from_output(times, steps,\n                                            previous_model_output):\n  \"\"\"Canonicalizes either 
relative or absolute times, with error checking.\"\"\"\n  if steps is not None and times is not None:\n    raise ValueError(\"Only one of `steps` and `times` may be specified.\")\n  if steps is None and times is None:\n    raise ValueError(\"One of `steps` and `times` must be specified.\")\n  if times is not None:\n    times = numpy.array(times)\n    if len(times.shape) != 2:\n      times = times[None, ...]\n    if (previous_model_output[feature_keys.FilteringResults.TIMES].shape[0] !=\n        times.shape[0]):\n      raise ValueError(\n          (\"`times` must have a batch dimension matching\"\n           \" the previous model output (got a batch dimension of {} for `times`\"\n           \" and {} for the previous model output).\").format(\n               times.shape[0], previous_model_output[\n                   feature_keys.FilteringResults.TIMES].shape[0]))\n    if not (previous_model_output[feature_keys.FilteringResults.TIMES][:, -1] <\n            times[:, 0]).all():\n      raise ValueError(\"Prediction times must be after the corresponding \"\n                       \"previous model output.\")\n  if steps is not None:\n    predict_times = (\n        previous_model_output[feature_keys.FilteringResults.TIMES][:, -1:] + 1 +\n        numpy.arange(steps)[None, ...])\n  else:\n    predict_times = times\n  return predict_times\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/timeseries/saved_model_utils.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Convenience functions for working with time series saved_models.\n\n@@predict_continuation\n@@cold_start_filter\n@@filter_continuation\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy\n\nfrom tensorflow.python.util.all_util import remove_undocumented\nfrom tensorflow_estimator.python.estimator.canned.timeseries import feature_keys as _feature_keys\nfrom tensorflow_estimator.python.estimator.canned.timeseries import head as _head\nfrom tensorflow_estimator.python.estimator.canned.timeseries import model_utils as _model_utils\n\n\ndef _canonicalize_numpy_data(data, require_single_batch):\n  \"\"\"Do basic checking and reshaping for Numpy data.\n\n  Args:\n    data: A dictionary mapping keys to Numpy arrays, with several possible\n      shapes (requires keys `TrainEvalFeatures.TIMES` and\n      `TrainEvalFeatures.VALUES`): Single example; `TIMES` is a scalar and\n        `VALUES` is either a scalar or a vector of length [number of features].\n        Sequence; `TIMES` is a vector of shape [series length], `VALUES` either\n        has shape [series length] (univariate) or [series length x number of\n        features] (multivariate). 
Batch of sequences; `TIMES` is an array of\n        shape [batch size x series length], `VALUES` has shape [batch size x\n        series length] or [batch size x series length x number of features]. In\n        any case, `VALUES` and any exogenous features must have their shapes\n        prefixed by the shape of the value corresponding to the `TIMES` key.\n    require_single_batch: If True, raises an error if the provided data has a\n      batch dimension > 1.\n\n  Returns:\n    A dictionary with features normalized to have shapes prefixed with [batch\n    size x series length]. The sizes of dimensions which were omitted in the\n    inputs are 1.\n  Raises:\n    ValueError: If dimensions are incorrect or do not match, or required\n      features are missing.\n  \"\"\"\n  features = {key: numpy.array(value) for key, value in data.items()}\n  if (_feature_keys.TrainEvalFeatures.TIMES not in features or\n      _feature_keys.TrainEvalFeatures.VALUES not in features):\n    raise ValueError(\"{} and {} are required features.\".format(\n        _feature_keys.TrainEvalFeatures.TIMES,\n        _feature_keys.TrainEvalFeatures.VALUES))\n  times = features[_feature_keys.TrainEvalFeatures.TIMES]\n  for key, value in features.items():\n    if value.shape[:len(times.shape)] != times.shape:\n      raise ValueError(\n          (\"All features must have their shapes prefixed by the shape of the\"\n           \" times feature. 
Got shape {} for feature '{}', but shape {} for\"\n           \" '{}'\").format(value.shape, key, times.shape,\n                           _feature_keys.TrainEvalFeatures.TIMES))\n  if not times.shape:  # a single example\n    if not features[_feature_keys.TrainEvalFeatures.VALUES].shape:  # univariate\n      # Add a feature dimension (with one feature)\n      features[_feature_keys.TrainEvalFeatures.VALUES] = features[\n          _feature_keys.TrainEvalFeatures.VALUES][..., None]\n    elif len(features[_feature_keys.TrainEvalFeatures.VALUES].shape) > 1:\n      raise ValueError(\n          (\"Got an unexpected number of dimensions for the '{}' feature.\"\n           \" Was expecting at most 1 dimension\"\n           \" ([number of features]) since '{}' does not \"\n           \"have a batch or time dimension, but got shape {}\").format(\n               _feature_keys.TrainEvalFeatures.VALUES,\n               _feature_keys.TrainEvalFeatures.TIMES,\n               features[_feature_keys.TrainEvalFeatures.VALUES].shape))\n    # Add trivial batch and time dimensions for every feature\n    features = {key: value[None, None, ...] 
for key, value in features.items()}\n  if len(times.shape) == 1:  # shape [series length]\n    if len(features[_feature_keys.TrainEvalFeatures.VALUES].shape\n          ) == 1:  # shape [series length]\n      # Add a feature dimension (with one feature)\n      features[_feature_keys.TrainEvalFeatures.VALUES] = features[\n          _feature_keys.TrainEvalFeatures.VALUES][..., None]\n    elif len(features[_feature_keys.TrainEvalFeatures.VALUES].shape) > 2:\n      raise ValueError(\n          (\"Got an unexpected number of dimensions for the '{}' feature.\"\n           \" Was expecting at most 2 dimensions\"\n           \" ([series length, number of features]) since '{}' does not \"\n           \"have a batch dimension, but got shape {}\").format(\n               _feature_keys.TrainEvalFeatures.VALUES,\n               _feature_keys.TrainEvalFeatures.TIMES,\n               features[_feature_keys.TrainEvalFeatures.VALUES].shape))\n    # Add trivial batch dimensions for every feature\n    features = {key: value[None, ...] for key, value in features.items()}\n  elif len(features[_feature_keys.TrainEvalFeatures.TIMES].shape\n          ) != 2:  # shape [batch size, series length]\n    raise ValueError(\n        (\"Got an unexpected number of dimensions for times. 
Was expecting at \"\n         \"most two ([batch size, series length]), but got shape {}.\").format(\n             times.shape))\n  if require_single_batch:\n    # We don't expect input to be already batched; batching is done later\n    if features[_feature_keys.TrainEvalFeatures.TIMES].shape[0] != 1:\n      raise ValueError(\"Got batch input, was expecting unbatched input.\")\n  return features\n\n\ndef _colate_features_to_feeds_and_fetches(signature,\n                                          features,\n                                          graph,\n                                          continue_from=None):\n  \"\"\"Uses a saved model signature to construct feed and fetch dictionaries.\"\"\"\n  if continue_from is None:\n    state_values = {}\n  elif _feature_keys.FilteringResults.STATE_TUPLE in continue_from:\n    # We're continuing from an evaluation, so we need to unpack/flatten state.\n    state_values = _head.state_to_dictionary(\n        continue_from[_feature_keys.FilteringResults.STATE_TUPLE])\n  else:\n    state_values = continue_from\n  input_feed_tensors_by_name = {\n      input_key: graph.as_graph_element(input_value.name)\n      for input_key, input_value in signature.inputs.items()\n  }\n  output_tensors_by_name = {\n      output_key: graph.as_graph_element(output_value.name)\n      for output_key, output_value in signature.outputs.items()\n  }\n  feed_dict = {}\n  for state_key, state_value in state_values.items():\n    feed_dict[input_feed_tensors_by_name[state_key]] = state_value\n  for feature_key, feature_value in features.items():\n    feed_dict[input_feed_tensors_by_name[feature_key]] = feature_value\n  return output_tensors_by_name, feed_dict\n\n\ndef predict_continuation(continue_from,\n                         signatures,\n                         session,\n                         steps=None,\n                         times=None,\n                         exogenous_features=None):\n  \"\"\"Perform prediction using an exported saved 
model.\n\n  Args:\n    continue_from: A dictionary containing the results of either an Estimator's\n      evaluate method or filter_continuation. Used to determine the model state\n      to make predictions starting from.\n    signatures: The `MetaGraphDef` protocol buffer returned from\n      `tf.saved_model.loader.load`. Used to determine the names of Tensors to\n      feed and fetch. Must be from the same model as `continue_from`.\n    session: The session to use. The session's graph must be the one into which\n      `tf.saved_model.loader.load` loaded the model.\n    steps: The number of steps to predict (scalar), starting after the\n      evaluation or filtering. If `times` is specified, `steps` must not be; one\n      is required.\n    times: A [batch_size x window_size] array of integers (not a Tensor)\n      indicating times to make predictions for. These times must be after the\n      corresponding evaluation or filtering. If `steps` is specified, `times`\n      must not be; one is required. If the batch dimension is omitted, it is\n      assumed to be 1.\n    exogenous_features: Optional dictionary. If specified, indicates exogenous\n      features for the model to use while making the predictions. 
Values must\n      have shape [batch_size x window_size x ...], where `batch_size` matches\n      the batch dimension used when creating `continue_from`, and `window_size`\n      is either the `steps` argument or the `window_size` of the `times`\n      argument (depending on which was specified).\n\n  Returns:\n    A dictionary with model-specific predictions (typically having keys \"mean\"\n    and \"covariance\") and a _feature_keys.PredictionResults.TIMES key indicating\n    the times for which the predictions were computed.\n  Raises:\n    ValueError: If `times` or `steps` are misspecified.\n  \"\"\"\n  if exogenous_features is None:\n    exogenous_features = {}\n  predict_times = _model_utils.canonicalize_times_or_steps_from_output(\n      times=times, steps=steps, previous_model_output=continue_from)\n  features = {_feature_keys.PredictionFeatures.TIMES: predict_times}\n  features.update(exogenous_features)\n  predict_signature = signatures.signature_def[\n      _feature_keys.SavedModelLabels.PREDICT]\n  output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches(\n      continue_from=continue_from,\n      signature=predict_signature,\n      features=features,\n      graph=session.graph)\n  output = session.run(output_tensors_by_name, feed_dict=feed_dict)\n  output[_feature_keys.PredictionResults.TIMES] = features[\n      _feature_keys.PredictionFeatures.TIMES]\n  return output\n\n\ndef cold_start_filter(signatures, session, features):\n  \"\"\"Perform filtering using an exported saved model.\n\n  Filtering refers to updating model state based on new observations.\n  Predictions based on the returned model state will be conditioned on these\n  observations.\n\n  Starts from the model's default/uninformed state.\n\n  Args:\n    signatures: The `MetaGraphDef` protocol buffer returned from\n      `tf.saved_model.loader.load`. Used to determine the names of Tensors to\n      feed and fetch. 
Must be from the same model as `continue_from`.\n    session: The session to use. The session's graph must be the one into which\n      `tf.saved_model.loader.load` loaded the model.\n    features: A dictionary mapping keys to Numpy arrays, with several possible\n      shapes (requires keys `FilteringFeatures.TIMES` and\n      `FilteringFeatures.VALUES`): Single example; `TIMES` is a scalar and\n        `VALUES` is either a scalar or a vector of length [number of features].\n        Sequence; `TIMES` is a vector of shape [series length], `VALUES` either\n        has shape [series length] (univariate) or [series length x number of\n        features] (multivariate). Batch of sequences; `TIMES` is an array of\n        shape [batch size x series length], `VALUES` has shape [batch size x\n        series length] or [batch size x series length x number of features]. In\n        any case, `VALUES` and any exogenous features must have their shapes\n        prefixed by the shape of the value corresponding to the `TIMES` key.\n\n  Returns:\n    A dictionary containing model state updated to account for the observations\n    in `features`.\n  \"\"\"\n  filter_signature = signatures.signature_def[\n      _feature_keys.SavedModelLabels.COLD_START_FILTER]\n  features = _canonicalize_numpy_data(data=features, require_single_batch=False)\n  output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches(\n      signature=filter_signature, features=features, graph=session.graph)\n  output = session.run(output_tensors_by_name, feed_dict=feed_dict)\n  # Make it easier to chain filter -> predict by keeping track of the current\n  # time.\n  output[_feature_keys.FilteringResults.TIMES] = features[\n      _feature_keys.FilteringFeatures.TIMES]\n  return output\n\n\ndef filter_continuation(continue_from, signatures, session, features):\n  \"\"\"Perform filtering using an exported saved model.\n\n  Filtering refers to updating model state based on new observations.\n  Predictions 
based on the returned model state will be conditioned on these\n  observations.\n\n  Args:\n    continue_from: A dictionary containing the results of either an Estimator's\n      evaluate method or a previous filter step (cold start or continuation).\n      Used to determine the model state to start filtering from.\n    signatures: The `MetaGraphDef` protocol buffer returned from\n      `tf.saved_model.loader.load`. Used to determine the names of Tensors to\n      feed and fetch. Must be from the same model as `continue_from`.\n    session: The session to use. The session's graph must be the one into which\n      `tf.saved_model.loader.load` loaded the model.\n    features: A dictionary mapping keys to Numpy arrays, with several possible\n      shapes (requires keys `FilteringFeatures.TIMES` and\n      `FilteringFeatures.VALUES`): Single example; `TIMES` is a scalar and\n        `VALUES` is either a scalar or a vector of length [number of features].\n        Sequence; `TIMES` is a vector of shape [series length], `VALUES` either\n        has shape [series length] (univariate) or [series length x number of\n        features] (multivariate). Batch of sequences; `TIMES` is an array of\n        shape [batch size x series length], `VALUES` has shape [batch size x\n        series length] or [batch size x series length x number of features]. 
In\n        any case, `VALUES` and any exogenous features must have their shapes\n        prefixed by the shape of the value corresponding to the `TIMES` key.\n\n  Returns:\n    A dictionary containing model state updated to account for the observations\n    in `features`.\n  \"\"\"\n  filter_signature = signatures.signature_def[\n      _feature_keys.SavedModelLabels.FILTER]\n  features = _canonicalize_numpy_data(data=features, require_single_batch=False)\n  output_tensors_by_name, feed_dict = _colate_features_to_feeds_and_fetches(\n      continue_from=continue_from,\n      signature=filter_signature,\n      features=features,\n      graph=session.graph)\n  output = session.run(output_tensors_by_name, feed_dict=feed_dict)\n  # Make it easier to chain filter -> predict by keeping track of the current\n  # time.\n  output[_feature_keys.FilteringResults.TIMES] = features[\n      _feature_keys.FilteringFeatures.TIMES]\n  return output\n\n\nremove_undocumented(module_name=__name__)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/timeseries/state_management.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Classes for wrapping a model to operate on different data shapes.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport abc\nfrom tensorflow_estimator.python.estimator import estimator_lib\nfrom tensorflow_estimator.python.estimator.canned.timeseries import feature_keys\n\n\nclass PassthroughStateManager(object):\n  \"\"\"A minimal wrapper for models which do not need state management.\"\"\"\n\n  def __init__(self):\n    self._input_statistics = None\n    self._graph_initialized = False\n\n  def initialize_graph(self, model, input_statistics=None):\n    \"\"\"Adds required operations to the graph.\"\"\"\n    del model  # unused\n    self._graph_initialized = True\n    self._input_statistics = input_statistics\n\n  def define_loss(self, model, features, mode):\n    \"\"\"Wrap \"model\" with StateManager-specific operations.\n\n    Args:\n      model: The model (inheriting from TimeSeriesModel) to manage state for.\n      features: A dictionary with the following key/value pairs:\n        feature_keys.TrainEvalFeatures.TIMES: A [batch size x window size]\n          Tensor with times for each observation.\n        feature_keys.TrainEvalFeatures.VALUES: A [batch size x window 
size x num\n          features] Tensor with values for each observation.\n      mode: The tf.estimator.ModeKeys mode to use (TRAIN or EVAL).\n\n    Returns:\n      A ModelOutputs object.\n    Raises:\n      ValueError: If start state was specified.\n    \"\"\"\n    if feature_keys.State.STATE_TUPLE in features:\n      raise ValueError(\n          \"Overriding start state is not supported for this model.\")\n    return model.define_loss(features, mode)\n\n\nclass _OverridableStateManager(PassthroughStateManager):\n  \"\"\"Base class for state managers which support overriding model state.\"\"\"\n\n  @abc.abstractmethod\n  def _define_loss_with_saved_state(self, model, features, mode):\n    pass\n\n  def define_loss(self, model, features, mode):\n    \"\"\"Switches between explicit start state and managed state.\"\"\"\n    if feature_keys.FilteringFeatures.STATE_TUPLE in features:\n      # Explicit start state has been provided, so we should use that.\n      if mode == estimator_lib.ModeKeys.TRAIN:\n        raise ValueError(\n            \"Overriding saved state for training is not supported (but a value \"\n            \"for feature {} was specified).\".format(\n                feature_keys.FilteringFeatures.STATE_TUPLE))\n      start_state = features[feature_keys.FilteringFeatures.STATE_TUPLE]\n      del features[feature_keys.FilteringFeatures.STATE_TUPLE]\n      return model.get_batch_loss(\n          features=features, mode=mode, state=start_state)\n    else:\n      # No explicit start state; use managed state.\n      return self._define_loss_with_saved_state(\n          model=model, features=features, mode=mode)\n\n\nclass FilteringOnlyStateManager(_OverridableStateManager):\n  \"\"\"State manager for models which use state only for filtering.\n\n  Window-based models (ARModel) do not require state to be fed during training\n  (instead requiring a specific window size). 
Rather than requiring a minimum\n  window size for filtering, these models maintain this window in their state,\n  and so need state to be fed.\n  \"\"\"\n\n  def _define_loss_with_saved_state(self, model, features, mode):\n    return model.define_loss(features, mode)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/v1/__init__.py",
    "content": ""
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/v1/baseline_estimator_test_v1.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for BaselineEstimatorV1.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport shutil\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.canned import baseline\nfrom tensorflow_estimator.python.estimator.canned import head as head_lib\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\n\n# Names of variables created by model.\nBIAS_NAME = 'baseline/bias'\n\n\ndef assert_close(expected, actual, rtol=1e-04, name='assert_close'):\n  with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope:\n    expected = ops.convert_to_tensor(expected, name='expected')\n    actual = ops.convert_to_tensor(actual, name='actual')\n    rdiff = tf.math.abs(expected - actual, 'diff') / tf.math.abs(expected)\n    rtol = ops.convert_to_tensor(rtol, name='rtol')\n    return tf.compat.v1.debugging.assert_less(\n        rdiff,\n        rtol,\n        
data=('Condition expected =~ actual did not hold element-wise:'\n              'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff,\n              'rtol = ', rtol,),\n        name=scope)\n\n\ndef save_variables_to_ckpt(model_dir):\n  init_all_op = [tf.compat.v1.initializers.global_variables()]\n  with tf.compat.v1.Session() as sess:\n    sess.run(init_all_op)\n    tf.compat.v1.train.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))\n\n\ndef _baseline_estimator_fn(weight_column=None, label_dimension=1, **kwargs):\n  return baseline.BaselineEstimator(\n      head=head_lib._regression_head(\n          weight_column=weight_column,\n          label_dimension=label_dimension,\n          loss_reduction=tf.compat.v1.losses.Reduction.SUM),\n      **kwargs)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass BaselineEstimatorEvaluationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_evaluation_batch(self):\n    \"\"\"Tests evaluation for batch_size==2.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([13.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir)\n    eval_metrics = baseline_estimator.evaluate(\n        input_fn=lambda: ({\n            'age': ((1,), (1,))\n        }, ((10.,), (10.,))), steps=1)\n\n    # Logit is bias = 13, while label is 10.\n    # Loss per example is 3**2 = 9.\n    # Training loss is the sum over batch = 9 + 9 = 18\n    # Average loss is the average over batch = 9\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 18.,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            
metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def test_evaluation_weights(self):\n    \"\"\"Tests evaluation with weights.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([13.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    def _input_fn():\n      features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))}\n      labels = ((10.,), (10.,))\n      return features, labels\n\n    baseline_estimator = _baseline_estimator_fn(\n        weight_column='weights', model_dir=self._model_dir)\n    eval_metrics = baseline_estimator.evaluate(input_fn=_input_fn, steps=1)\n\n    # Logit is bias = 13, while label is 10.\n    # Loss per example is 3**2 = 9.\n    # Training loss is the weighted sum over batch = 9 + 2*9 = 27\n    # average loss is the weighted average = 9 + 2*9 / (1 + 2) = 9\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 27.,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def test_evaluation_for_multi_dimensions(self):\n    label_dim = 2\n    with tf.Graph().as_default():\n      tf.Variable([46.0, 58.0], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_estimator = _baseline_estimator_fn(\n        label_dimension=label_dim, model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': np.array([[2., 4., 5.]]),\n        },\n        y=np.array([[46., 58.]]),\n        batch_size=1,\n        num_epochs=None,\n        shuffle=False)\n    
eval_metrics = baseline_estimator.evaluate(input_fn=input_fn, steps=1)\n\n    self.assertItemsEqual(\n        (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,\n         metric_keys.MetricKeys.PREDICTION_MEAN,\n         metric_keys.MetricKeys.LABEL_MEAN, tf.compat.v1.GraphKeys.GLOBAL_STEP),\n        eval_metrics.keys())\n\n    # Logit is bias which is [46, 58]\n    self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass BaselineEstimatorPredictTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_1d(self):\n    \"\"\"Tests predict when all variables are one-dimensional.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([.2], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_estimator = _baseline_estimator_fn(model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': np.array([[2.]])},\n        y=None,\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    predictions = baseline_estimator.predict(input_fn=predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # x * weight + bias = 2. * 10. 
+ .2 = 20.2\n    self.assertAllClose([[.2]], predicted_scores)\n\n  def testMultiDim(self):\n    \"\"\"Tests predict when all variables are multi-dimenstional.\"\"\"\n    batch_size = 2\n    label_dimension = 3\n    with tf.Graph().as_default():\n      tf.Variable(  # shape=[label_dimension]\n          [.2, .4, .6], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_estimator = _baseline_estimator_fn(\n        label_dimension=label_dimension, model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        # x shape=[batch_size, x_dim]\n        x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])},\n        y=None,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n    predictions = baseline_estimator.predict(input_fn=predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # score = bias, shape=[batch_size, label_dimension]\n    self.assertAllClose([[0.2, 0.4, 0.6], [0.2, 0.4, 0.6]], predicted_scores)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass BaselineEstimatorIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, label_dimension, prediction_length):\n    feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    est = _baseline_estimator_fn(\n        label_dimension=label_dimension, model_dir=self._model_dir)\n\n    # TRAIN\n    # learn y = x\n    est.train(train_input_fn, steps=200)\n\n    # EVALUTE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(200, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n   
 self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores))\n\n    # PREDICT\n    predictions = np.array(\n        [x['predictions'] for x in est.predict(predict_input_fn)])\n    self.assertAllEqual((prediction_length, label_dimension), predictions.shape)\n\n    # EXPORT\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    label_dimension = 2\n    input_dimension = label_dimension\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=None,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        label_dimension=label_dimension,\n        prediction_length=prediction_length)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass BaselineEstimatorTrainingTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def 
tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _mock_optimizer(self, expected_loss=None):\n    expected_var_names = ['%s:0' % BIAS_NAME]\n\n    def _minimize(loss, global_step=None, var_list=None):\n      trainable_vars = var_list or tf.compat.v1.get_collection(\n          tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n      self.assertItemsEqual(expected_var_names,\n                            [var.name for var in trainable_vars])\n\n      # Verify loss. We can't check the value directly, so we add an assert op.\n      self.assertEquals(0, loss.shape.ndims)\n      if expected_loss is None:\n        if global_step is not None:\n          return tf.compat.v1.assign_add(global_step, 1).op\n        return tf.no_op()\n      assert_loss = assert_close(\n          tf.cast(expected_loss, name='expected', dtype=tf.dtypes.float32),\n          loss,\n          name='assert_loss')\n      with tf.control_dependencies((assert_loss,)):\n        if global_step is not None:\n          return tf.compat.v1.assign_add(global_step, 1).op\n        return tf.no_op()\n\n    mock_optimizer = tf.compat.v1.test.mock.NonCallableMock(\n        spec=tf.compat.v1.train.Optimizer,\n        wraps=tf.compat.v1.train.Optimizer(\n            use_locking=False, name='my_optimizer'))\n    mock_optimizer.minimize = tf.compat.v1.test.mock.MagicMock(wraps=_minimize)\n\n    # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks.\n    # So, return mock_optimizer itself for deepcopy.\n    mock_optimizer.__deepcopy__ = lambda _: mock_optimizer\n    return mock_optimizer\n\n  def _assert_checkpoint(self,\n                         label_dimension,\n                         expected_global_step,\n                         expected_bias=None):\n    shapes = {\n        name: shape\n        for (name, shape) in tf.train.list_variables(self._model_dir)\n    }\n\n    self.assertEqual([], 
shapes[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertEqual(\n        expected_global_step,\n        tf.train.load_variable(self._model_dir,\n                               tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n    self.assertEqual([label_dimension], shapes[BIAS_NAME])\n    if expected_bias is not None:\n      self.assertEqual(expected_bias,\n                       tf.train.load_variable(self._model_dir, BIAS_NAME))\n\n  def testFromScratch(self):\n    # Create BaselineRegressor.\n    label = 5.\n    age = 17\n    # loss = (logits - label)^2 = (0 - 5.)^2 = 25.\n    mock_optimizer = self._mock_optimizer(expected_loss=25.)\n    baseline_estimator = _baseline_estimator_fn(\n        model_dir=self._model_dir, optimizer=mock_optimizer)\n    self.assertEqual(0, mock_optimizer.minimize.call_count)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    baseline_estimator.train(\n        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)\n    self.assertEqual(1, mock_optimizer.minimize.call_count)\n    self._assert_checkpoint(\n        label_dimension=1, expected_global_step=num_steps, expected_bias=[0.])\n\n  def testFromCheckpoint(self):\n    # Create initial checkpoint.\n    bias = 7.0\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable([bias], name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # logits = bias = 6.\n    # loss = (logits - label)^2 = (7 - 5)^2 = 4\n    mock_optimizer = self._mock_optimizer(expected_loss=4.)\n    baseline_estimator = _baseline_estimator_fn(\n        model_dir=self._model_dir, optimizer=mock_optimizer)\n    self.assertEqual(0, mock_optimizer.minimize.call_count)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    
baseline_estimator.train(\n        input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps)\n    self.assertEqual(1, mock_optimizer.minimize.call_count)\n    self._assert_checkpoint(\n        label_dimension=1,\n        expected_global_step=initial_global_step + num_steps,\n        expected_bias=[bias])\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/v1/baseline_test_v1.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for v1 version of baseline.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\nimport os\nimport shutil\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.core.example import example_pb2\nfrom tensorflow.core.example import feature_pb2\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.canned import baseline\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.inputs import pandas_io\n\ntry:\n  # pylint: disable=g-import-not-at-top\n  import pandas as pd\n  HAS_PANDAS = True\nexcept IOError:\n  # Pandas writes a temporary file during import. 
If it fails, don't use pandas.\n  HAS_PANDAS = False\nexcept ImportError:\n  HAS_PANDAS = False\n\n# pylint rules which are disabled by default for test files.\n# pylint: disable=invalid-name,protected-access,missing-docstring\n\n# Names of variables created by model.\nBIAS_NAME = 'baseline/bias'\n\n\ndef assert_close(expected, actual, rtol=1e-04, name='assert_close'):\n  with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope:\n    expected = ops.convert_to_tensor(expected, name='expected')\n    actual = ops.convert_to_tensor(actual, name='actual')\n    rdiff = tf.math.abs(expected - actual, 'diff') / tf.math.abs(expected)\n    rtol = ops.convert_to_tensor(rtol, name='rtol')\n    return tf.compat.v1.debugging.assert_less(\n        rdiff,\n        rtol,\n        data=('Condition expected =~ actual did not hold element-wise:'\n              'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff,\n              'rtol = ', rtol,),\n        name=scope)\n\n\ndef save_variables_to_ckpt(model_dir):\n  init_all_op = [tf.compat.v1.initializers.global_variables()]\n  with tf.compat.v1.Session() as sess:\n    sess.run(init_all_op)\n    tf.compat.v1.train.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))\n\n\ndef queue_parsed_features(feature_map):\n  tensors_to_enqueue = []\n  keys = []\n  for key, tensor in six.iteritems(feature_map):\n    keys.append(key)\n    tensors_to_enqueue.append(tensor)\n  queue_dtypes = [x.dtype for x in tensors_to_enqueue]\n  input_queue = tf.queue.FIFOQueue(capacity=100, dtypes=queue_dtypes)\n  tf.compat.v1.train.queue_runner.add_queue_runner(\n      tf.compat.v1.train.queue_runner.QueueRunner(\n          input_queue, [input_queue.enqueue(tensors_to_enqueue)]))\n  dequeued_tensors = input_queue.dequeue()\n  return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}\n\n\ndef sorted_key_dict(unsorted_dict):\n  return {k: unsorted_dict[k] for k in sorted(unsorted_dict)}\n\n\ndef sigmoid(x):\n  
return 1 / (1 + np.exp(-1.0 * x))\n\n\ndef _baseline_regressor_fn(*args, **kwargs):\n  return baseline.BaselineRegressor(*args, **kwargs)\n\n\ndef _baseline_classifier_fn(*args, **kwargs):\n  return baseline.BaselineClassifier(*args, **kwargs)\n\n\n# Tests for Baseline Regressor.\n\n\n# TODO(b/36813849): Add tests with dynamic shape inputs using placeholders.\n@test_util.run_v1_only('Tests v1 only symbols')\nclass BaselineRegressorEvaluationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_evaluation_for_simple_data(self):\n    with tf.Graph().as_default():\n      tf.Variable([13.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)\n    eval_metrics = baseline_regressor.evaluate(\n        input_fn=lambda: ({\n            'age': ((1,),)\n        }, ((10.,),)), steps=1)\n\n    # Logit is bias = 13, while label is 10. 
Loss is 3**2 = 9.\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 9.,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def test_evaluation_batch(self):\n    \"\"\"Tests evaluation for batch_size==2.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([13.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)\n    eval_metrics = baseline_regressor.evaluate(\n        input_fn=lambda: ({\n            'age': ((1,), (1,))\n        }, ((10.,), (10.,))), steps=1)\n\n    # Logit is bias = 13, while label is 10.\n    # Loss per example is 3**2 = 9.\n    # Training loss is the sum over batch = 9 + 9 = 18\n    # Average loss is the average over batch = 9\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 18.,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def test_evaluation_weights(self):\n    \"\"\"Tests evaluation with weights.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([13.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    def _input_fn():\n      features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))}\n      labels = ((10.,), (10.,))\n      return features, labels\n\n    baseline_regressor = _baseline_regressor_fn(\n        weight_column='weights', model_dir=self._model_dir)\n    
eval_metrics = baseline_regressor.evaluate(input_fn=_input_fn, steps=1)\n\n    # Logit is bias = 13, while label is 10.\n    # Loss per example is 3**2 = 9.\n    # Training loss is the weighted sum over batch = 9 + 2*9 = 27\n    # average loss is the weighted average = 9 + 2*9 / (1 + 2) = 9\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 27.,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def test_evaluation_for_multi_dimensions(self):\n    label_dim = 2\n    with tf.Graph().as_default():\n      tf.Variable([46.0, 58.0], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_regressor = _baseline_regressor_fn(\n        label_dimension=label_dim, model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': np.array([[2., 4., 5.]]),\n        },\n        y=np.array([[46., 58.]]),\n        batch_size=1,\n        num_epochs=None,\n        shuffle=False)\n    eval_metrics = baseline_regressor.evaluate(input_fn=input_fn, steps=1)\n\n    self.assertItemsEqual(\n        (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,\n         metric_keys.MetricKeys.PREDICTION_MEAN,\n         metric_keys.MetricKeys.LABEL_MEAN, tf.compat.v1.GraphKeys.GLOBAL_STEP),\n        eval_metrics.keys())\n\n    # Logit is bias which is [46, 58]\n    self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass BaselineRegressorPredictTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def 
test_1d(self):\n    \"\"\"Tests predict when all variables are one-dimensional.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([.2], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': np.array([[2.]])},\n        y=None,\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    predictions = baseline_regressor.predict(input_fn=predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # x * weight + bias = 2. * 10. + .2 = 20.2\n    self.assertAllClose([[.2]], predicted_scores)\n\n  def testMultiDim(self):\n    \"\"\"Tests predict when all variables are multi-dimenstional.\"\"\"\n    batch_size = 2\n    label_dimension = 3\n    with tf.Graph().as_default():\n      tf.Variable(  # shape=[label_dimension]\n          [.2, .4, .6], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    baseline_regressor = _baseline_regressor_fn(\n        label_dimension=label_dimension, model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        # x shape=[batch_size, x_dim]\n        x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])},\n        y=None,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n    predictions = baseline_regressor.predict(input_fn=predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # score = bias, shape=[batch_size, label_dimension]\n    self.assertAllClose([[0.2, 0.4, 0.6], [0.2, 0.4, 0.6]], predicted_scores)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass BaselineRegressorIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def 
tearDown(self):
    if self._model_dir:
      tf.compat.v1.summary.FileWriterCache.clear()
      shutil.rmtree(self._model_dir)

  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
                          input_dimension, label_dimension, prediction_length):
    feature_columns = [
        tf.feature_column.numeric_column('x', shape=(input_dimension,))
    ]
    est = _baseline_regressor_fn(
        label_dimension=label_dimension, model_dir=self._model_dir)

    # TRAIN
    # learn y = x
    est.train(train_input_fn, steps=200)

    # EVALUATE
    scores = est.evaluate(eval_input_fn)
    self.assertEqual(200, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])
    self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores))

    # PREDICT
    predictions = np.array(
        [x['predictions'] for x in est.predict(predict_input_fn)])
    self.assertAllEqual((prediction_length, label_dimension), predictions.shape)

    # EXPORT
    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(
        feature_columns)
    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
        feature_spec)
    export_dir = est.export_saved_model(tempfile.mkdtemp(),
                                        serving_input_receiver_fn)
    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))

  def test_numpy_input_fn(self):
    """Tests complete flow with numpy_input_fn."""
    label_dimension = 2
    input_dimension = label_dimension
    batch_size = 10
    prediction_length = batch_size
    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
    data = data.reshape(batch_size, label_dimension)

    train_input_fn = numpy_io.numpy_input_fn(
        x={'x': data},
        y=data,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)
    eval_input_fn = numpy_io.numpy_input_fn(
        x={'x': data},
        y=data,
       
 batch_size=batch_size,
        num_epochs=1,
        shuffle=False)
    predict_input_fn = numpy_io.numpy_input_fn(
        x={'x': data},
        y=None,
        batch_size=batch_size,
        num_epochs=1,
        shuffle=False)

    self._test_complete_flow(
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        predict_input_fn=predict_input_fn,
        input_dimension=input_dimension,
        label_dimension=label_dimension,
        prediction_length=prediction_length)

  def test_pandas_input_fn(self):
    """Tests complete flow with pandas_input_fn."""
    if not HAS_PANDAS:
      return

    # Pandas DataFrame naturally supports 1 dim data only.
    label_dimension = 1
    input_dimension = label_dimension
    batch_size = 10
    data = np.array([1., 2., 3., 4.], dtype=np.float32)
    x = pd.DataFrame({'x': data})
    y = pd.Series(data)
    prediction_length = 4

    train_input_fn = pandas_io.pandas_input_fn(
        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)
    eval_input_fn = pandas_io.pandas_input_fn(
        x=x, y=y, batch_size=batch_size, shuffle=False)
    predict_input_fn = pandas_io.pandas_input_fn(
        x=x, batch_size=batch_size, shuffle=False)

    self._test_complete_flow(
        train_input_fn=train_input_fn,
        eval_input_fn=eval_input_fn,
        predict_input_fn=predict_input_fn,
        input_dimension=input_dimension,
        label_dimension=label_dimension,
        prediction_length=prediction_length)

  def test_input_fn_from_parse_example(self):
    """Tests complete flow with input_fn constructed from parse_example."""
    label_dimension = 2
    input_dimension = label_dimension
    batch_size = 10
    prediction_length = batch_size
    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)
    data = data.reshape(batch_size, label_dimension)

    serialized_examples = []
    for datum 
in data:\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n                  'y':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(\n                              value=datum[:label_dimension])),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([input_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([label_dimension], tf.dtypes.float32),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=input_dimension,\n        label_dimension=label_dimension,\n        prediction_length=prediction_length)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass BaselineRegressorTrainingTest(tf.test.TestCase):\n\n  
def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _mock_optimizer(self, expected_loss=None):\n    expected_var_names = ['%s:0' % BIAS_NAME]\n\n    def _minimize(loss, global_step=None, var_list=None):\n      trainable_vars = var_list or tf.compat.v1.get_collection(\n          tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n      self.assertItemsEqual(expected_var_names,\n                            [var.name for var in trainable_vars])\n\n      # Verify loss. We can't check the value directly, so we add an assert op.\n      self.assertEquals(0, loss.shape.ndims)\n      if expected_loss is None:\n        if global_step is not None:\n          return tf.compat.v1.assign_add(global_step, 1).op\n        return tf.no_op()\n      assert_loss = assert_close(\n          tf.cast(expected_loss, name='expected', dtype=tf.dtypes.float32),\n          loss,\n          name='assert_loss')\n      with tf.control_dependencies((assert_loss,)):\n        if global_step is not None:\n          return tf.compat.v1.assign_add(global_step, 1).op\n        return tf.no_op()\n\n    mock_optimizer = tf.compat.v1.test.mock.NonCallableMock(\n        spec=tf.compat.v1.train.Optimizer,\n        wraps=tf.compat.v1.train.Optimizer(\n            use_locking=False, name='my_optimizer'))\n    mock_optimizer.minimize = tf.compat.v1.test.mock.MagicMock(wraps=_minimize)\n\n    # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks.\n    # So, return mock_optimizer itself for deepcopy.\n    mock_optimizer.__deepcopy__ = lambda _: mock_optimizer\n    return mock_optimizer\n\n  def _assert_checkpoint(self,\n                         label_dimension,\n                         expected_global_step,\n                         expected_bias=None):\n    shapes = {\n        name: shape\n        for (name, shape) in 
tf.train.list_variables(self._model_dir)\n    }\n\n    self.assertEqual([], shapes[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertEqual(\n        expected_global_step,\n        tf.train.load_variable(self._model_dir,\n                               tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n    self.assertEqual([label_dimension], shapes[BIAS_NAME])\n    if expected_bias is not None:\n      self.assertEqual(expected_bias,\n                       tf.train.load_variable(self._model_dir, BIAS_NAME))\n\n  def testFromScratchWithDefaultOptimizer(self):\n    # Create BaselineRegressor.\n    label = 5.\n    age = 17\n    baseline_regressor = _baseline_regressor_fn(model_dir=self._model_dir)\n\n    # Train for a few steps, and validate final checkpoint.\n    num_steps = 10\n    baseline_regressor.train(\n        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)\n    self._assert_checkpoint(label_dimension=1, expected_global_step=num_steps)\n\n  def testTrainWithOneDimLabel(self):\n    label_dimension = 1\n    batch_size = 20\n    est = _baseline_regressor_fn(\n        label_dimension=label_dimension, model_dir=self._model_dir)\n    data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32)\n    self.assertEqual((batch_size,), data_rank_1.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'age': data_rank_1},\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(label_dimension=1, expected_global_step=200)\n\n  def testTrainWithOneDimWeight(self):\n    label_dimension = 1\n    batch_size = 20\n    est = _baseline_regressor_fn(\n        label_dimension=label_dimension,\n        weight_column='w',\n        model_dir=self._model_dir)\n\n    data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32)\n    self.assertEqual((batch_size,), data_rank_1.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        
x={
            'age': data_rank_1,
            'w': data_rank_1
        },
        y=data_rank_1,
        batch_size=batch_size,
        num_epochs=None,
        shuffle=True)
    est.train(train_input_fn, steps=200)
    self._assert_checkpoint(label_dimension=1, expected_global_step=200)

  def testFromScratch(self):
    # Create BaselineRegressor.
    label = 5.
    age = 17
    # loss = (logits - label)^2 = (0 - 5.)^2 = 25.
    mock_optimizer = self._mock_optimizer(expected_loss=25.)
    baseline_regressor = _baseline_regressor_fn(
        model_dir=self._model_dir, optimizer=mock_optimizer)
    self.assertEqual(0, mock_optimizer.minimize.call_count)

    # Train for a few steps, and validate optimizer and final checkpoint.
    num_steps = 10
    baseline_regressor.train(
        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)
    self.assertEqual(1, mock_optimizer.minimize.call_count)
    self._assert_checkpoint(
        label_dimension=1, expected_global_step=num_steps, expected_bias=[0.])

  def testFromCheckpoint(self):
    # Create initial checkpoint.
    bias = 7.0
    initial_global_step = 100
    with tf.Graph().as_default():
      tf.Variable([bias], name=BIAS_NAME)
      tf.Variable(
          initial_global_step,
          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,
          dtype=tf.dtypes.int64)
      save_variables_to_ckpt(self._model_dir)

    # logits = bias = 7.
    # loss = (logits - label)^2 = (7 - 5)^2 = 4
    mock_optimizer = self._mock_optimizer(expected_loss=4.)
    baseline_regressor = _baseline_regressor_fn(
        model_dir=self._model_dir, optimizer=mock_optimizer)
    self.assertEqual(0, mock_optimizer.minimize.call_count)

    # Train for a few steps, and validate optimizer and final checkpoint.
    num_steps = 10
    baseline_regressor.train(
        input_fn=lambda: ({'age': ((17,),)}, ((5.,),)), steps=num_steps)
    self.assertEqual(1, 
mock_optimizer.minimize.call_count)\n    self._assert_checkpoint(\n        label_dimension=1,\n        expected_global_step=initial_global_step + num_steps,\n        expected_bias=[bias])\n\n  def testFromCheckpointMultiBatch(self):\n    # Create initial checkpoint.\n    bias = 5.0\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable([bias], name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # logits = bias\n    # logits[0] = 5.\n    # logits[1] = 5.\n    # loss = sum(logits - label)^2 = (5 - 5)^2 + (5 - 3)^2 = 4\n    mock_optimizer = self._mock_optimizer(expected_loss=4.)\n    baseline_regressor = _baseline_regressor_fn(\n        model_dir=self._model_dir, optimizer=mock_optimizer)\n    self.assertEqual(0, mock_optimizer.minimize.call_count)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    baseline_regressor.train(\n        input_fn=lambda: ({'age': ((17,), (15,))}, ((5.,), (3.,))),\n        steps=num_steps)\n    self.assertEqual(1, mock_optimizer.minimize.call_count)\n    self._assert_checkpoint(\n        label_dimension=1,\n        expected_global_step=initial_global_step + num_steps,\n        expected_bias=bias)\n\n\n# Tests for Baseline Classifier.\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass BaselineClassifierTrainingTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _mock_optimizer(self, expected_loss=None):\n    expected_var_names = ['%s:0' % BIAS_NAME]\n\n    def _minimize(loss, global_step):\n      trainable_vars = tf.compat.v1.get_collection(\n          tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n      self.assertItemsEqual(expected_var_names,\n                          
  [var.name for var in trainable_vars])\n\n      # Verify loss. We can't check the value directly, so we add an assert op.\n      self.assertEquals(0, loss.shape.ndims)\n      if expected_loss is None:\n        return tf.compat.v1.assign_add(global_step, 1).op\n      assert_loss = assert_close(\n          tf.cast(expected_loss, name='expected', dtype=tf.dtypes.float32),\n          loss,\n          name='assert_loss')\n      with tf.control_dependencies((assert_loss,)):\n        return tf.compat.v1.assign_add(global_step, 1).op\n\n    mock_optimizer = tf.compat.v1.test.mock.NonCallableMock(\n        spec=tf.compat.v1.train.Optimizer,\n        wraps=tf.compat.v1.train.Optimizer(\n            use_locking=False, name='my_optimizer'))\n    mock_optimizer.minimize = tf.compat.v1.test.mock.MagicMock(wraps=_minimize)\n\n    # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks.\n    # So, return mock_optimizer itself for deepcopy.\n    mock_optimizer.__deepcopy__ = lambda _: mock_optimizer\n    return mock_optimizer\n\n  def _assert_checkpoint(self,\n                         n_classes,\n                         expected_global_step,\n                         expected_bias=None):\n    logits_dimension = n_classes if n_classes > 2 else 1\n\n    shapes = {\n        name: shape\n        for (name, shape) in tf.train.list_variables(self._model_dir)\n    }\n\n    self.assertEqual([], shapes[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertEqual(\n        expected_global_step,\n        tf.train.load_variable(self._model_dir,\n                               tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n    self.assertEqual([logits_dimension], shapes[BIAS_NAME])\n    if expected_bias is not None:\n      self.assertAllEqual(expected_bias,\n                          tf.train.load_variable(self._model_dir, BIAS_NAME))\n\n  def _testFromScratchWithDefaultOptimizer(self, n_classes):\n    label = 0\n    age = 17\n    est = baseline.BaselineClassifier(\n        
n_classes=n_classes, model_dir=self._model_dir)\n\n    # Train for a few steps, and validate final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)\n    self._assert_checkpoint(n_classes, num_steps)\n\n  def testBinaryClassesFromScratchWithDefaultOptimizer(self):\n    self._testFromScratchWithDefaultOptimizer(n_classes=2)\n\n  def testMultiClassesFromScratchWithDefaultOptimizer(self):\n    self._testFromScratchWithDefaultOptimizer(n_classes=4)\n\n  def _testTrainWithTwoDimsLabel(self, n_classes):\n    batch_size = 20\n\n    est = baseline.BaselineClassifier(\n        n_classes=n_classes, model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    data_rank_2 = np.array([[0], [1]])\n    self.assertEqual((2,), data_rank_1.shape)\n    self.assertEqual((2, 1), data_rank_2.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'age': data_rank_1},\n        y=data_rank_2,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithTwoDimsLabel(self):\n    self._testTrainWithTwoDimsLabel(n_classes=2)\n\n  def testMultiClassesTrainWithTwoDimsLabel(self):\n    self._testTrainWithTwoDimsLabel(n_classes=4)\n\n  def _testTrainWithOneDimLabel(self, n_classes):\n    batch_size = 20\n\n    est = baseline.BaselineClassifier(\n        n_classes=n_classes, model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    self.assertEqual((2,), data_rank_1.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'age': data_rank_1},\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithOneDimLabel(self):\n    self._testTrainWithOneDimLabel(n_classes=2)\n\n  def 
testMultiClassesTrainWithOneDimLabel(self):\n    self._testTrainWithOneDimLabel(n_classes=4)\n\n  def _testTrainWithTwoDimsWeight(self, n_classes):\n    batch_size = 20\n\n    est = baseline.BaselineClassifier(\n        weight_column='w', n_classes=n_classes, model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    data_rank_2 = np.array([[0], [1]])\n    self.assertEqual((2,), data_rank_1.shape)\n    self.assertEqual((2, 1), data_rank_2.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': data_rank_1,\n            'w': data_rank_2\n        },\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithTwoDimsWeight(self):\n    self._testTrainWithTwoDimsWeight(n_classes=2)\n\n  def testMultiClassesTrainWithTwoDimsWeight(self):\n    self._testTrainWithTwoDimsWeight(n_classes=4)\n\n  def _testTrainWithOneDimWeight(self, n_classes):\n    batch_size = 20\n\n    est = baseline.BaselineClassifier(\n        weight_column='w', n_classes=n_classes, model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    self.assertEqual((2,), data_rank_1.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': data_rank_1,\n            'w': data_rank_1\n        },\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithOneDimWeight(self):\n    self._testTrainWithOneDimWeight(n_classes=2)\n\n  def testMultiClassesTrainWithOneDimWeight(self):\n    self._testTrainWithOneDimWeight(n_classes=4)\n\n  def _testFromScratch(self, n_classes):\n    label = 1\n    age = 17\n    # For binary classifier:\n    #   loss = sigmoid_cross_entropy(logits, label) where logits=0 (weights 
are\n    #   all zero initially) and label = 1 so,\n    #      loss = 1 * -log ( sigmoid(logits) ) = 0.69315\n    # For multi class classifier:\n    #   loss = cross_entropy(logits, label) where logits are all 0s (weights are\n    #   all zero initially) and label = 1 so,\n    #      loss = 1 * -log ( 1.0 / n_classes )\n    # For this particular test case, as logits are same, the formula\n    # 1 * -log ( 1.0 / n_classes ) covers both binary and multi class cases.\n    mock_optimizer = self._mock_optimizer(\n        expected_loss=(-1 * math.log(1.0 / n_classes)))\n\n    est = baseline.BaselineClassifier(\n        n_classes=n_classes,\n        optimizer=mock_optimizer,\n        model_dir=self._model_dir)\n    self.assertEqual(0, mock_optimizer.minimize.call_count)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)\n    self.assertEqual(1, mock_optimizer.minimize.call_count)\n    self._assert_checkpoint(\n        n_classes,\n        expected_global_step=num_steps,\n        expected_bias=[0.] 
if n_classes == 2 else [.0] * n_classes)\n\n  def testBinaryClassesFromScratch(self):\n    self._testFromScratch(n_classes=2)\n\n  def testMultiClassesFromScratch(self):\n    self._testFromScratch(n_classes=4)\n\n  def _testFromCheckpoint(self, n_classes):\n    # Create initial checkpoint.\n    label = 1\n    age = 17\n    bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # For binary classifier:\n    #   logits = bias = -1.\n    #   loss = sigmoid_cross_entropy(logits, label)\n    #   so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133\n    # For multi class classifier:\n    #   loss = cross_entropy(logits, label)\n    #   where logits = bias and label = 1\n    #   so, loss = 1 * -log ( softmax(logits)[1] )\n    if n_classes == 2:\n      expected_loss = 1.3133\n    else:\n      logits = bias\n      logits_exp = np.exp(logits)\n      softmax = logits_exp / logits_exp.sum()\n      expected_loss = -1 * math.log(softmax[label])\n\n    mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)\n\n    est = baseline.BaselineClassifier(\n        n_classes=n_classes,\n        optimizer=mock_optimizer,\n        model_dir=self._model_dir)\n    self.assertEqual(0, mock_optimizer.minimize.call_count)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)\n    self.assertEqual(1, mock_optimizer.minimize.call_count)\n    self._assert_checkpoint(\n        n_classes,\n        expected_global_step=initial_global_step + num_steps,\n        expected_bias=bias)\n\n  def testBinaryClassesFromCheckpoint(self):\n    
self._testFromCheckpoint(n_classes=2)\n\n  def testMultiClassesFromCheckpoint(self):\n    self._testFromCheckpoint(n_classes=4)\n\n  def _testFromCheckpointFloatLabels(self, n_classes):\n    \"\"\"Tests float labels for binary classification.\"\"\"\n    # Create initial checkpoint.\n    if n_classes > 2:\n      return\n    label = 0.8\n    age = 17\n    bias = [-1.0]\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # logits = bias = -1.\n    # loss = sigmoid_cross_entropy(logits, label)\n    # => loss = -0.8 * log(sigmoid(-1)) -0.2 * log(sigmoid(+1)) = 1.1132617\n    mock_optimizer = self._mock_optimizer(expected_loss=1.1132617)\n\n    est = baseline.BaselineClassifier(\n        n_classes=n_classes,\n        optimizer=mock_optimizer,\n        model_dir=self._model_dir)\n    self.assertEqual(0, mock_optimizer.minimize.call_count)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=num_steps)\n    self.assertEqual(1, mock_optimizer.minimize.call_count)\n\n  def testBinaryClassesFromCheckpointFloatLabels(self):\n    self._testFromCheckpointFloatLabels(n_classes=2)\n\n  def testMultiClassesFromCheckpointFloatLabels(self):\n    self._testFromCheckpointFloatLabels(n_classes=4)\n\n  def _testFromCheckpointMultiBatch(self, n_classes):\n    # Create initial checkpoint.\n    label = [1, 0]\n    age = [17, 18.5]\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). 
In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # For binary classifier:\n    #   logits = bias\n    #   logits[0] = -1.\n    #   logits[1] = -1.\n    #   loss = sigmoid_cross_entropy(logits, label)\n    #   so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133\n    #       loss[1] = (1 - 0) * -log ( 1- sigmoid(-1) ) = 0.3132\n    # For multi class classifier:\n    #   loss = cross_entropy(logits, label)\n    #   where logits = bias and label = [1, 0]\n    #   so, loss = 1 * -log ( softmax(logits)[label] )\n    if n_classes == 2:\n      expected_loss = (1.3133 + 0.3132)\n    else:\n      # Expand logits since batch_size=2\n      logits = bias * np.ones(shape=(2, 1))\n      logits_exp = np.exp(logits)\n      softmax_row_0 = logits_exp[0] / logits_exp[0].sum()\n      softmax_row_1 = logits_exp[1] / logits_exp[1].sum()\n      expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])\n      expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])\n      expected_loss = expected_loss_0 + expected_loss_1\n\n    mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)\n\n    est = baseline.BaselineClassifier(\n        n_classes=n_classes,\n        optimizer=mock_optimizer,\n        model_dir=self._model_dir)\n    self.assertEqual(0, mock_optimizer.minimize.call_count)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(input_fn=lambda: ({'age': (age)}, (label)), steps=num_steps)\n    self.assertEqual(1, mock_optimizer.minimize.call_count)\n    self._assert_checkpoint(\n        n_classes,\n        
expected_global_step=initial_global_step + num_steps,\n        expected_bias=bias)\n\n  def testBinaryClassesFromCheckpointMultiBatch(self):\n    self._testFromCheckpointMultiBatch(n_classes=2)\n\n  def testMultiClassesFromCheckpointMultiBatch(self):\n    self._testFromCheckpointMultiBatch(n_classes=4)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass BaselineClassifierEvaluationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _test_evaluation_for_simple_data(self, n_classes):\n    label = 1\n    age = 1.\n\n    bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes\n\n    with tf.Graph().as_default():\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = _baseline_classifier_fn(\n        n_classes=n_classes, model_dir=self._model_dir)\n    eval_metrics = est.evaluate(\n        input_fn=lambda: ({'age': ((age,),)}, ((label,),)), steps=1)\n\n    if n_classes == 2:\n      # Binary classes: loss = -log(sigmoid(-1)) / batch size = 1.3133\n      # Prediction = sigmoid(-1) = 0.2689\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: 1.3133,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: 1.3133,\n          metric_keys.MetricKeys.ACCURACY: 0.,\n          metric_keys.MetricKeys.PRECISION: 0.,\n          metric_keys.MetricKeys.RECALL: 0.,\n          metric_keys.MetricKeys.PREDICTION_MEAN: 0.2689,\n          metric_keys.MetricKeys.LABEL_MEAN: 1.,\n          metric_keys.MetricKeys.ACCURACY_BASELINE: 1,\n          metric_keys.MetricKeys.AUC: 0.,\n          metric_keys.MetricKeys.AUC_PR: 1.,\n      }\n    else:\n      # Multi classes: loss = 1 * -log ( softmax(logits)[label] )\n      logits = bias\n      logits_exp = 
np.exp(logits)\n      softmax = logits_exp / logits_exp.sum()\n      expected_loss = -1 * math.log(softmax[label])\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: expected_loss,\n          metric_keys.MetricKeys.ACCURACY: 0.,\n      }\n\n    self.assertAllClose(\n        sorted_key_dict(expected_metrics),\n        sorted_key_dict(eval_metrics),\n        rtol=1e-3)\n\n  def test_binary_classes_evaluation_for_simple_data(self):\n    self._test_evaluation_for_simple_data(n_classes=2)\n\n  def test_multi_classes_evaluation_for_simple_data(self):\n    self._test_evaluation_for_simple_data(n_classes=4)\n\n  def _test_evaluation_batch(self, n_classes):\n    \"\"\"Tests evaluation for batch_size==2.\"\"\"\n    label = [1, 0]\n    age = [17., 18.]\n    bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = _baseline_classifier_fn(\n        n_classes=n_classes, model_dir=self._model_dir)\n    eval_metrics = est.evaluate(\n        input_fn=lambda: ({'age': (age)}, (label)), steps=1)\n\n    if n_classes == 2:\n      # Logits are (-1., -1.) 
labels are (1, 0).\n      # Loss is\n      #   loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133\n      #   loss for row 2: (1 - 0) * -log(1 - sigmoid(-1)) = 0.3132\n      # Prediction = sigmoid(-1) = 0.2689\n      expected_loss = 1.3133 + 0.3132\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,\n          metric_keys.MetricKeys.ACCURACY: 0.5,\n          metric_keys.MetricKeys.PRECISION: 0.,\n          metric_keys.MetricKeys.RECALL: 0.,\n          metric_keys.MetricKeys.PREDICTION_MEAN: 0.2689,\n          metric_keys.MetricKeys.LABEL_MEAN: 0.5,\n          metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,\n          metric_keys.MetricKeys.AUC: 0.5,\n          metric_keys.MetricKeys.AUC_PR: 0.75,\n      }\n    else:\n      # Expand logits since batch_size=2\n      logits = bias * np.ones(shape=(2, 1))\n      logits_exp = np.exp(logits)\n      softmax_row_0 = logits_exp[0] / logits_exp[0].sum()\n      softmax_row_1 = logits_exp[1] / logits_exp[1].sum()\n      expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])\n      expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])\n      expected_loss = expected_loss_0 + expected_loss_1\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,\n          metric_keys.MetricKeys.ACCURACY: 0.5,\n      }\n\n    self.assertAllClose(\n        sorted_key_dict(expected_metrics),\n        sorted_key_dict(eval_metrics),\n        rtol=1e-3)\n\n  def test_binary_classes_evaluation_batch(self):\n    self._test_evaluation_batch(n_classes=2)\n\n  def test_multi_classes_evaluation_batch(self):\n    self._test_evaluation_batch(n_classes=4)\n\n  def _test_evaluation_weights(self, n_classes):\n    \"\"\"Tests evaluation with weights.\"\"\"\n\n    label = 
[1, 0]\n    age = [17., 18.]\n    weights = [1., 2.]\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    bias = [-1.0] if n_classes == 2 else [-1.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = _baseline_classifier_fn(\n        n_classes=n_classes, weight_column='w', model_dir=self._model_dir)\n    eval_metrics = est.evaluate(\n        input_fn=lambda: ({'age': (age), 'w': (weights)}, (label)), steps=1)\n\n    if n_classes == 2:\n      # Logits are (-1., -1.) labels are (1, 0).\n      # Loss is\n      #   loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133\n      #   loss for row 2: (1 - 0) * -log(1 - sigmoid(-1)) = 0.3132\n      #   weights = [1., 2.]\n      expected_loss = 1.3133 * 1. + 0.3132 * 2.\n      loss_mean = expected_loss / (1.0 + 2.0)\n      label_mean = np.average(label, weights=weights)\n      logits = [-1, -1]\n      logistics = sigmoid(np.array(logits))\n      predictions_mean = np.average(logistics, weights=weights)\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: loss_mean,\n          metric_keys.MetricKeys.ACCURACY: 2. / (1. 
+ 2.),\n          metric_keys.MetricKeys.PRECISION: 0.,\n          metric_keys.MetricKeys.RECALL: 0.,\n          metric_keys.MetricKeys.PREDICTION_MEAN: predictions_mean,\n          metric_keys.MetricKeys.LABEL_MEAN: label_mean,\n          metric_keys.MetricKeys.ACCURACY_BASELINE:\n              (max(label_mean, 1 - label_mean)),\n          metric_keys.MetricKeys.AUC: 0.5,\n          metric_keys.MetricKeys.AUC_PR: 2. / (1. + 2.),\n      }\n    else:\n      # Multi classes: unweighted_loss = 1 * -log ( soft_max(logits)[label] )\n      # Expand logits since batch_size=2\n      logits = bias * np.ones(shape=(2, 1))\n      logits_exp = np.exp(logits)\n      softmax_row_0 = logits_exp[0] / logits_exp[0].sum()\n      softmax_row_1 = logits_exp[1] / logits_exp[1].sum()\n      expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])\n      expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])\n      loss_mean = np.average([expected_loss_0, expected_loss_1],\n                             weights=weights)\n      expected_loss = loss_mean * np.sum(weights)\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: loss_mean,\n          metric_keys.MetricKeys.ACCURACY: 2. / (1. 
+ 2.),\n      }\n\n    self.assertAllClose(\n        sorted_key_dict(expected_metrics),\n        sorted_key_dict(eval_metrics),\n        rtol=1e-3)\n\n  def test_binary_classes_evaluation_weights(self):\n    self._test_evaluation_weights(n_classes=2)\n\n  def test_multi_classes_evaluation_weights(self):\n    self._test_evaluation_weights(n_classes=4)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass BaselineClassifierPredictTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _testPredictions(self, n_classes, label_vocabulary, label_output_fn):\n    \"\"\"Tests predict when all variables are one-dimensional.\"\"\"\n    age = 1.\n\n    bias = [10.0] if n_classes == 2 else [10.0] * n_classes\n\n    with tf.Graph().as_default():\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = _baseline_classifier_fn(\n        label_vocabulary=label_vocabulary,\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'age': np.array([[age]])},\n        y=None,\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    predictions = list(est.predict(input_fn=predict_input_fn))\n\n    if n_classes == 2:\n      scalar_logits = bias[0]\n      two_classes_logits = [0, scalar_logits]\n      two_classes_logits_exp = np.exp(two_classes_logits)\n      softmax = two_classes_logits_exp / two_classes_logits_exp.sum()\n\n      expected_predictions = {\n          'class_ids': [1],\n          'all_class_ids': [0, 1],\n          'classes': [label_output_fn(1)],\n          'all_classes': [label_output_fn(0),\n                          label_output_fn(1)],\n          'logistic': [sigmoid(np.array(scalar_logits))],\n          'logits': [scalar_logits],\n        
  'probabilities': softmax,\n      }\n    else:\n      onedim_logits = np.array(bias)\n      class_ids = onedim_logits.argmax()\n      all_class_ids = list(range(len(onedim_logits)))\n      logits_exp = np.exp(onedim_logits)\n      softmax = logits_exp / logits_exp.sum()\n      expected_predictions = {\n          'class_ids': [class_ids],\n          'all_class_ids': all_class_ids,\n          'classes': [label_output_fn(class_ids)],\n          'all_classes': [label_output_fn(i) for i in all_class_ids],\n          'logits': onedim_logits,\n          'probabilities': softmax,\n      }\n\n    self.assertEqual(1, len(predictions))\n    # assertAllClose cannot handle byte type.\n    self.assertEqual(expected_predictions['classes'], predictions[0]['classes'])\n    expected_predictions.pop('classes')\n    predictions[0].pop('classes')\n    self.assertAllEqual(expected_predictions['all_classes'],\n                        predictions[0]['all_classes'])\n    expected_predictions.pop('all_classes')\n    predictions[0].pop('all_classes')\n    self.assertAllClose(\n        sorted_key_dict(expected_predictions), sorted_key_dict(predictions[0]))\n\n  def testBinaryClassesWithoutLabelVocabulary(self):\n    n_classes = 2\n    self._testPredictions(\n        n_classes,\n        label_vocabulary=None,\n        label_output_fn=lambda x: ('%s' % x).encode())\n\n  def testBinaryClassesWithLabelVocabulary(self):\n    n_classes = 2\n    self._testPredictions(\n        n_classes,\n        label_vocabulary=['class_vocab_{}'.format(i) for i in range(n_classes)],\n        label_output_fn=lambda x: ('class_vocab_%s' % x).encode())\n\n  def testMultiClassesWithoutLabelVocabulary(self):\n    n_classes = 4\n    self._testPredictions(\n        n_classes,\n        label_vocabulary=None,\n        label_output_fn=lambda x: ('%s' % x).encode())\n\n  def testMultiClassesWithLabelVocabulary(self):\n    n_classes = 4\n    self._testPredictions(\n        n_classes,\n        
label_vocabulary=['class_vocab_{}'.format(i) for i in range(n_classes)],\n        label_output_fn=lambda x: ('class_vocab_%s' % x).encode())\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass BaselineClassifierIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self, n_classes, train_input_fn, eval_input_fn,\n                          predict_input_fn, input_dimension, prediction_length):\n    feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    est = _baseline_classifier_fn(\n        n_classes=n_classes, model_dir=self._model_dir)\n\n    # TRAIN\n    # learn y = x\n    est.train(train_input_fn, steps=200)\n\n    # EVALUATE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(200, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores))\n\n    # PREDICT\n    predictions = np.array(\n        [x['classes'] for x in est.predict(predict_input_fn)])\n    self.assertAllEqual((prediction_length, 1), predictions.shape)\n\n    # EXPORT\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def _test_numpy_input_fn(self, n_classes):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    input_dimension = 4\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, input_dimension)\n    target = np.array([1] * batch_size)\n\n    
train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=target,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=target,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=None,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n\n    self._test_complete_flow(\n        n_classes=n_classes,\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        prediction_length=prediction_length)\n\n  def test_binary_classes_numpy_input_fn(self):\n    self._test_numpy_input_fn(n_classes=2)\n\n  def test_multi_classes_numpy_input_fn(self):\n    self._test_numpy_input_fn(n_classes=4)\n\n  def _test_pandas_input_fn(self, n_classes):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n\n    # Pandas DataFrame naturally supports 1 dim data only.\n    input_dimension = 1\n    batch_size = 10\n    data = np.array([1., 2., 3., 4.], dtype=np.float32)\n    target = np.array([1, 0, 1, 0], dtype=np.int32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(target)\n    prediction_length = 4\n\n    train_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        n_classes=n_classes,\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        
prediction_length=prediction_length)\n\n  def test_binary_classes_pandas_input_fn(self):\n    self._test_pandas_input_fn(n_classes=2)\n\n  def test_multi_classes_pandas_input_fn(self):\n    self._test_pandas_input_fn(n_classes=4)\n\n  def _test_input_fn_from_parse_example(self, n_classes):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    input_dimension = 2\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, input_dimension)\n    target = np.array([1] * batch_size, dtype=np.int64)\n\n    serialized_examples = []\n    for x, y in zip(data, target):\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=x)),\n                  'y':\n                      feature_pb2.Feature(\n                          int64_list=feature_pb2.Int64List(value=[y])),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([input_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([1], tf.dtypes.int64),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = 
tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        n_classes=n_classes,\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=input_dimension,\n        prediction_length=prediction_length)\n\n  def test_binary_classes_input_fn_from_parse_example(self):\n    self._test_input_fn_from_parse_example(n_classes=2)\n\n  def test_multi_classes_input_fn_from_parse_example(self):\n    self._test_input_fn_from_parse_example(n_classes=4)\n\n\n# Tests for Baseline logit_fn.\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass BaselineLogitFnTest(tf.test.TestCase):\n\n  def test_basic_logit_correctness(self):\n    \"\"\"baseline_logit_fn simply returns the bias variable.\"\"\"\n    with tf.Graph().as_default():\n      logit_fn = baseline._baseline_logit_fn_builder(num_outputs=2)\n      logits = logit_fn(features={'age': [[23.], [31.]]})\n      with tf.compat.v1.variable_scope('baseline', reuse=True):\n        bias_var = tf.compat.v1.get_variable('bias')\n      with tf.compat.v1.Session() as sess:\n        sess.run([tf.compat.v1.initializers.global_variables()])\n        self.assertAllClose([[0., 0.], [0., 0.]], logits.eval())\n        sess.run(bias_var.assign([10., 5.]))\n        self.assertAllClose([[10., 5.], [10., 5.]], logits.eval())\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/v1/dnn_estimator_test_v1.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for DNNEstimator.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport shutil\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.canned import dnn\nfrom tensorflow_estimator.python.estimator.canned import head as head_lib\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.canned.v1 import dnn_testing_utils_v1\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\n\n\ndef _dnn_estimator_fn(weight_column=None, label_dimension=1, **kwargs):\n  \"\"\"Returns a DNNEstimator that uses regression_head.\"\"\"\n  return dnn.DNNEstimator(\n      head=head_lib._regression_head(\n          weight_column=weight_column,\n          label_dimension=label_dimension,\n          # Tests in core (from which this test inherits) test the sum loss.\n          loss_reduction=tf.compat.v1.losses.Reduction.SUM),\n      **kwargs)\n\n\ndef _dnn_estimator_classifier_fn(n_classes=3, **kwargs):\n  return dnn.DNNEstimator(\n      
head=head_lib._multi_class_head_with_softmax_cross_entropy_loss(\n          n_classes=n_classes),\n      **kwargs)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNEstimatorEvaluateTest(\n    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest.__init__(\n        self, _dnn_estimator_fn)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNEstimatorPredictTest(dnn_testing_utils_v1.BaseDNNRegressorPredictTest,\n                              tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorPredictTest.__init__(\n        self, _dnn_estimator_fn)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNEstimatorTrainTest(dnn_testing_utils_v1.BaseDNNRegressorTrainTest,\n                            tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorTrainTest.__init__(\n        self, _dnn_estimator_fn)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNEstimatorWarmStartingTest(dnn_testing_utils_v1.BaseDNNWarmStartingTest,\n                                   tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNWarmStartingTest.__init__(\n        self, _dnn_estimator_classifier_fn, _dnn_estimator_fn)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNEstimatorIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      
tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, label_dimension, batch_size):\n    feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    est = dnn.DNNEstimator(\n        head=head_lib._regression_head(label_dimension=label_dimension),\n        hidden_units=(2, 2),\n        feature_columns=feature_columns,\n        model_dir=self._model_dir)\n\n    # Train\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # Evaluate\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # Predict\n    predictions = np.array([\n        x[prediction_keys.PredictionKeys.PREDICTIONS]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, label_dimension), predictions.shape)\n\n    # Export\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, 
y=data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/v1/dnn_linear_combined_estimator_test_v1.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for DNNLinearCombinedEstimatorV1.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport shutil\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.feature_column import feature_column\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.canned import dnn_linear_combined\nfrom tensorflow_estimator.python.estimator.canned import head as head_lib\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.canned.v1 import dnn_testing_utils_v1\nfrom tensorflow_estimator.python.estimator.canned.v1 import linear_testing_utils_v1\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\n\n\ndef _dnn_only_estimator_fn(hidden_units,\n                           feature_columns,\n                           model_dir=None,\n                           label_dimension=1,\n                           weight_column=None,\n                           optimizer='Adagrad',\n                           activation_fn=tf.nn.relu,\n                           dropout=None,\n    
                       input_layer_partitioner=None,\n                           config=None):\n  return dnn_linear_combined.DNNLinearCombinedEstimator(\n      head=head_lib._regression_head(\n          weight_column=weight_column,\n          label_dimension=label_dimension,\n          # Tests in core (from which this test inherits) test the sum loss.\n          loss_reduction=tf.compat.v1.losses.Reduction.SUM),\n      model_dir=model_dir,\n      dnn_feature_columns=feature_columns,\n      dnn_optimizer=optimizer,\n      dnn_hidden_units=hidden_units,\n      dnn_activation_fn=activation_fn,\n      dnn_dropout=dropout,\n      input_layer_partitioner=input_layer_partitioner,\n      config=config)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyEstimatorEvaluateTest(\n    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest.__init__(\n        self, _dnn_only_estimator_fn)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyEstimatorPredictTest(\n    dnn_testing_utils_v1.BaseDNNRegressorPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorPredictTest.__init__(\n        self, _dnn_only_estimator_fn)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyEstimatorTrainTest(dnn_testing_utils_v1.BaseDNNRegressorTrainTest,\n                                tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorTrainTest.__init__(\n        self, _dnn_only_estimator_fn)\n\n\ndef _linear_only_estimator_fn(feature_columns,\n                              
model_dir=None,\n                              label_dimension=1,\n                              weight_column=None,\n                              optimizer='Ftrl',\n                              config=None,\n                              partitioner=None,\n                              sparse_combiner='sum'):\n  return dnn_linear_combined.DNNLinearCombinedEstimator(\n      head=head_lib._regression_head(\n          weight_column=weight_column,\n          label_dimension=label_dimension,\n          # Tests in core (from which this test inherits) test the sum loss.\n          loss_reduction=tf.compat.v1.losses.Reduction.SUM),\n      model_dir=model_dir,\n      linear_feature_columns=feature_columns,\n      linear_optimizer=optimizer,\n      input_layer_partitioner=partitioner,\n      config=config,\n      linear_sparse_combiner=sparse_combiner)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyEstimatorEvaluateTest(\n    linear_testing_utils_v1.BaseLinearRegressorEvaluationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorEvaluationTest.__init__(\n        self, _linear_only_estimator_fn)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyEstimatorPredictTest(\n    linear_testing_utils_v1.BaseLinearRegressorPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorPredictTest.__init__(\n        self, _linear_only_estimator_fn)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyEstimatorTrainTest(\n    linear_testing_utils_v1.BaseLinearRegressorTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    
linear_testing_utils_v1.BaseLinearRegressorTrainingTest.__init__(\n        self, _linear_only_estimator_fn)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNLinearCombinedEstimatorIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, label_dimension, batch_size):\n    linear_feature_columns = [\n        feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    dnn_feature_columns = [\n        feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    feature_columns = linear_feature_columns + dnn_feature_columns\n    est = dnn_linear_combined.DNNLinearCombinedEstimator(\n        head=head_lib._regression_head(label_dimension=label_dimension),\n        linear_feature_columns=linear_feature_columns,\n        dnn_feature_columns=dnn_feature_columns,\n        dnn_hidden_units=(2, 2),\n        model_dir=self._model_dir)\n\n    # Train\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # Evaluate\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # Predict\n    predictions = np.array([\n        x[prediction_keys.PredictionKeys.PREDICTIONS]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, label_dimension), predictions.shape)\n\n    # Export\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                  
      serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/v1/dnn_linear_combined_test_v1.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for v1 version of dnn_linear_combined.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport shutil\nimport tempfile\n\nfrom absl.testing import parameterized\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.core.example import example_pb2\nfrom tensorflow.core.example import feature_pb2\nfrom tensorflow.python.feature_column import feature_column\nfrom tensorflow.python.feature_column import feature_column_v2\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator import estimator\nfrom tensorflow_estimator.python.estimator.canned import dnn_linear_combined\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.canned.v1 import dnn_testing_utils_v1\nfrom tensorflow_estimator.python.estimator.canned.v1 import linear_testing_utils_v1\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.inputs import pandas_io\n\ntry:\n  # pylint: disable=g-import-not-at-top\n  import pandas as 
pd\n  HAS_PANDAS = True\nexcept IOError:\n  # Pandas writes a temporary file during import. If it fails, don't use pandas.\n  HAS_PANDAS = False\nexcept ImportError:\n  HAS_PANDAS = False\n\n# This is so that we can easily switch between feature_column and\n# feature_column_v2 for testing.\n# Note that following V2 version of tests are for feature_column_v2, not the v2\n# version of canned estimator.\nfeature_column.numeric_column = feature_column._numeric_column\nfeature_column.categorical_column_with_hash_bucket = feature_column._categorical_column_with_hash_bucket  # pylint: disable=line-too-long\nfeature_column.categorical_column_with_vocabulary_list = feature_column._categorical_column_with_vocabulary_list  # pylint: disable=line-too-long\nfeature_column.categorical_column_with_vocabulary_file = feature_column._categorical_column_with_vocabulary_file  # pylint: disable=line-too-long\nfeature_column.embedding_column = feature_column._embedding_column\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyModelFnTest(dnn_testing_utils_v1.BaseDNNModelFnTest,\n                         tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNModelFnTest.__init__(self,\n                                                     self._dnn_only_model_fn)\n\n  def _dnn_only_model_fn(self,\n                         features,\n                         labels,\n                         mode,\n                         head,\n                         hidden_units,\n                         feature_columns,\n                         optimizer='Adagrad',\n                         activation_fn=tf.nn.relu,\n                         dropout=None,\n                         input_layer_partitioner=None,\n                         config=None):\n    return dnn_linear_combined._dnn_linear_combined_model_fn(\n        features=features,\n        
labels=labels,\n        mode=mode,\n        head=head,\n        linear_feature_columns=[],\n        dnn_hidden_units=hidden_units,\n        dnn_feature_columns=feature_columns,\n        dnn_optimizer=optimizer,\n        dnn_activation_fn=activation_fn,\n        dnn_dropout=dropout,\n        input_layer_partitioner=input_layer_partitioner,\n        config=config)\n\n\n# A function to mimic linear-regressor init reuse same tests.\ndef _linear_regressor_fn(feature_columns,\n                         model_dir=None,\n                         label_dimension=1,\n                         weight_column=None,\n                         optimizer='Ftrl',\n                         config=None,\n                         partitioner=None,\n                         sparse_combiner='sum'):\n  return dnn_linear_combined.DNNLinearCombinedRegressor(\n      model_dir=model_dir,\n      linear_feature_columns=feature_columns,\n      linear_optimizer=optimizer,\n      label_dimension=label_dimension,\n      weight_column=weight_column,\n      input_layer_partitioner=partitioner,\n      config=config,\n      linear_sparse_combiner=sparse_combiner)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyRegressorPartitionerTest(\n    linear_testing_utils_v1.BaseLinearRegressorPartitionerTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorPartitionerTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyRegressorPartitionerV2Test(\n    linear_testing_utils_v1.BaseLinearRegressorPartitionerTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorPartitionerTest.__init__(\n        self, 
_linear_regressor_fn, fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyRegressorEvaluationTest(\n    linear_testing_utils_v1.BaseLinearRegressorEvaluationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorEvaluationTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyRegressorEvaluationV2Test(\n    linear_testing_utils_v1.BaseLinearRegressorEvaluationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorEvaluationTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyRegressorPredictTest(\n    linear_testing_utils_v1.BaseLinearRegressorPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorPredictTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyRegressorPredictV2Test(\n    linear_testing_utils_v1.BaseLinearRegressorPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorPredictTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyRegressorIntegrationTest(\n    linear_testing_utils_v1.BaseLinearRegressorIntegrationTest,\n    tf.test.TestCase):\n\n  def 
__init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorIntegrationTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyRegressorIntegrationV2Test(\n    linear_testing_utils_v1.BaseLinearRegressorIntegrationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorIntegrationTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyRegressorTrainingTest(\n    linear_testing_utils_v1.BaseLinearRegressorTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorTrainingTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyRegressorTrainingV2Test(\n    linear_testing_utils_v1.BaseLinearRegressorTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorTrainingTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\ndef _linear_classifier_fn(feature_columns,\n                          model_dir=None,\n                          n_classes=2,\n                          weight_column=None,\n                          label_vocabulary=None,\n                          optimizer='Ftrl',\n                          config=None,\n                          partitioner=None,\n                          sparse_combiner='sum'):\n  return 
dnn_linear_combined.DNNLinearCombinedClassifier(\n      model_dir=model_dir,\n      linear_feature_columns=feature_columns,\n      linear_optimizer=optimizer,\n      n_classes=n_classes,\n      weight_column=weight_column,\n      label_vocabulary=label_vocabulary,\n      input_layer_partitioner=partitioner,\n      config=config,\n      linear_sparse_combiner=sparse_combiner)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyClassifierTrainingTest(\n    linear_testing_utils_v1.BaseLinearClassifierTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierTrainingTest.__init__(\n        self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyClassifierTrainingV2Test(\n    linear_testing_utils_v1.BaseLinearClassifierTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierTrainingTest.__init__(\n        self,\n        linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyClassifierClassesEvaluationTest(\n    linear_testing_utils_v1.BaseLinearClassifierEvaluationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierEvaluationTest.__init__(\n        self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyClassifierClassesEvaluationV2Test(\n    linear_testing_utils_v1.BaseLinearClassifierEvaluationTest,\n    tf.test.TestCase):\n\n  def 
__init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierEvaluationTest.__init__(\n        self,\n        linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyClassifierPredictTest(\n    linear_testing_utils_v1.BaseLinearClassifierPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierPredictTest.__init__(\n        self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyClassifierPredictV2Test(\n    linear_testing_utils_v1.BaseLinearClassifierPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierPredictTest.__init__(\n        self,\n        linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyClassifierIntegrationTest(\n    linear_testing_utils_v1.BaseLinearClassifierIntegrationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierIntegrationTest.__init__(\n        self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearOnlyClassifierIntegrationV2Test(\n    linear_testing_utils_v1.BaseLinearClassifierIntegrationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    
tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierIntegrationTest.__init__(\n        self,\n        linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\n@parameterized.parameters((feature_column,), (feature_column_v2,))\nclass DNNLinearCombinedRegressorIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow_helper(self, linear_feature_columns,\n                                 dnn_feature_columns, feature_spec,\n                                 train_input_fn, eval_input_fn,\n                                 predict_input_fn, input_dimension,\n                                 label_dimension, batch_size):\n    est = dnn_linear_combined.DNNLinearCombinedRegressor(\n        linear_feature_columns=linear_feature_columns,\n        dnn_hidden_units=(2, 2),\n        dnn_feature_columns=dnn_feature_columns,\n        label_dimension=label_dimension,\n        model_dir=self._model_dir)\n\n    # TRAIN\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # EVALUATE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # PREDICT\n    predictions = np.array([\n        x[prediction_keys.PredictionKeys.PREDICTIONS]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, label_dimension), predictions.shape)\n\n    # EXPORT\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    
self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, label_dimension, batch_size,\n                          fc_impl):\n    linear_feature_columns = [\n        fc_impl.numeric_column('x', shape=(input_dimension,))\n    ]\n    dnn_feature_columns = [\n        fc_impl.numeric_column('x', shape=(input_dimension,))\n    ]\n    feature_columns = linear_feature_columns + dnn_feature_columns\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    self._test_complete_flow_helper(linear_feature_columns, dnn_feature_columns,\n                                    feature_spec, train_input_fn, eval_input_fn,\n                                    predict_input_fn, input_dimension,\n                                    label_dimension, batch_size)\n\n  def _test_complete_flow_mix1(self, train_input_fn, eval_input_fn,\n                               predict_input_fn, input_dimension,\n                               label_dimension, batch_size, fc_impl):\n    del fc_impl\n    linear_feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    dnn_feature_columns = [\n        feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    feature_columns = linear_feature_columns + dnn_feature_columns\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    self._test_complete_flow_helper(linear_feature_columns, dnn_feature_columns,\n                                    feature_spec, train_input_fn, eval_input_fn,\n                                    predict_input_fn, input_dimension,\n                                    label_dimension, batch_size)\n\n  def _test_complete_flow_mix2(self, train_input_fn, eval_input_fn,\n                               predict_input_fn, input_dimension,\n                             
  label_dimension, batch_size, fc_impl):\n    del fc_impl\n    linear_feature_columns = [\n        feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    dnn_feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    feature_columns = linear_feature_columns + dnn_feature_columns\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    self._test_complete_flow_helper(linear_feature_columns, dnn_feature_columns,\n                                    feature_spec, train_input_fn, eval_input_fn,\n                                    predict_input_fn, input_dimension,\n                                    label_dimension, batch_size)\n\n  def _test_numpy_input_fn_helper(self, fc_impl, fn_to_run):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, batch_size=batch_size, shuffle=False)\n\n    fn_to_run(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size,\n        fc_impl=fc_impl)\n\n  def test_numpy_input_fn_basic(self, fc_impl):\n    self._test_numpy_input_fn_helper(fc_impl, self._test_complete_flow)\n\n  def test_numpy_input_fn_mix1(self, fc_impl):\n    self._test_numpy_input_fn_helper(fc_impl, self._test_complete_flow_mix1)\n\n 
 def test_numpy_input_fn_mix2(self, fc_impl):\n    self._test_numpy_input_fn_helper(fc_impl, self._test_complete_flow_mix2)\n\n  def _test_pandas_input_fn_helper(self, fc_impl, fn_to_run):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n    label_dimension = 1\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size, dtype=np.float32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(data)\n    train_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, shuffle=False)\n\n    fn_to_run(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size,\n        fc_impl=fc_impl)\n\n  def test_pandas_input_fn_basic(self, fc_impl):\n    self._test_pandas_input_fn_helper(fc_impl, self._test_complete_flow)\n\n  def test_pandas_input_fn_mix1(self, fc_impl):\n    self._test_pandas_input_fn_helper(fc_impl, self._test_complete_flow_mix1)\n\n  def test_pandas_input_fn_mix2(self, fc_impl):\n    self._test_pandas_input_fn_helper(fc_impl, self._test_complete_flow_mix2)\n\n  def _test_input_fn_from_parse_example_helper(self, fc_impl, fn_to_run):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n\n    serialized_examples = []\n    for datum in data:\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      
feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n                  'y':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([label_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([label_dimension], tf.dtypes.float32),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = linear_testing_utils_v1.queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = linear_testing_utils_v1.queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = linear_testing_utils_v1.queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    fn_to_run(\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size,\n        fc_impl=fc_impl)\n\n  def test_input_fn_from_parse_example_basic(self, fc_impl):\n    self._test_input_fn_from_parse_example_helper(fc_impl,\n                                                  self._test_complete_flow)\n\n  def test_input_fn_from_parse_example_mix1(self, 
fc_impl):\n    self._test_input_fn_from_parse_example_helper(fc_impl,\n                                                  self._test_complete_flow_mix1)\n\n  def test_input_fn_from_parse_example_mix2(self, fc_impl):\n    self._test_input_fn_from_parse_example_helper(fc_impl,\n                                                  self._test_complete_flow_mix2)\n\n\n# A function to mimic dnn-classifier init reuse same tests.\ndef _dnn_classifier_fn(hidden_units,\n                       feature_columns,\n                       model_dir=None,\n                       n_classes=2,\n                       weight_column=None,\n                       label_vocabulary=None,\n                       optimizer='Adagrad',\n                       config=None,\n                       input_layer_partitioner=None):\n  return dnn_linear_combined.DNNLinearCombinedClassifier(\n      model_dir=model_dir,\n      dnn_hidden_units=hidden_units,\n      dnn_feature_columns=feature_columns,\n      dnn_optimizer=optimizer,\n      n_classes=n_classes,\n      weight_column=weight_column,\n      label_vocabulary=label_vocabulary,\n      input_layer_partitioner=input_layer_partitioner,\n      config=config)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyClassifierEvaluateTest(\n    dnn_testing_utils_v1.BaseDNNClassifierEvaluateTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNClassifierEvaluateTest.__init__(\n        self, _dnn_classifier_fn, fc_impl=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyClassifierEvaluateV2Test(\n    dnn_testing_utils_v1.BaseDNNClassifierEvaluateTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNClassifierEvaluateTest.__init__(\n        self, 
_dnn_classifier_fn, fc_impl=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyClassifierPredictTest(\n    dnn_testing_utils_v1.BaseDNNClassifierPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNClassifierPredictTest.__init__(\n        self, _dnn_classifier_fn, fc_impl=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyClassifierPredictV2Test(\n    dnn_testing_utils_v1.BaseDNNClassifierPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNClassifierPredictTest.__init__(\n        self, _dnn_classifier_fn, fc_impl=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyClassifierTrainTest(\n    dnn_testing_utils_v1.BaseDNNClassifierTrainTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNClassifierTrainTest.__init__(\n        self, _dnn_classifier_fn, fc_impl=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyClassifierTrainV2Test(\n    dnn_testing_utils_v1.BaseDNNClassifierTrainTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNClassifierTrainTest.__init__(\n        self, _dnn_classifier_fn, fc_impl=feature_column_v2)\n\n\n# A function to mimic dnn-regressor init reuse same tests.\ndef _dnn_regressor_fn(hidden_units,\n                      feature_columns,\n                      model_dir=None,\n                      label_dimension=1,\n                      weight_column=None,\n                     
 optimizer='Adagrad',\n                      config=None,\n                      input_layer_partitioner=None):\n  return dnn_linear_combined.DNNLinearCombinedRegressor(\n      model_dir=model_dir,\n      dnn_hidden_units=hidden_units,\n      dnn_feature_columns=feature_columns,\n      dnn_optimizer=optimizer,\n      label_dimension=label_dimension,\n      weight_column=weight_column,\n      input_layer_partitioner=input_layer_partitioner,\n      config=config)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyRegressorEvaluateTest(\n    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyRegressorEvaluateV2Test(\n    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyRegressorPredictTest(\n    dnn_testing_utils_v1.BaseDNNRegressorPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorPredictTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyRegressorPredictV2Test(\n    dnn_testing_utils_v1.BaseDNNRegressorPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, 
methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorPredictTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyRegressorTrainTest(dnn_testing_utils_v1.BaseDNNRegressorTrainTest,\n                                tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorTrainTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNOnlyRegressorTrainV2Test(\n    dnn_testing_utils_v1.BaseDNNRegressorTrainTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorTrainTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\n@parameterized.parameters((feature_column,), (feature_column_v2,))\nclass DNNLinearCombinedClassifierIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _as_label(self, data_in_float):\n    return np.rint(data_in_float).astype(np.int64)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, n_classes, batch_size, fc_impl):\n    linear_feature_columns = [\n        fc_impl.numeric_column('x', shape=(input_dimension,))\n    ]\n    dnn_feature_columns = [\n        fc_impl.numeric_column('x', shape=(input_dimension,))\n    ]\n    feature_columns = linear_feature_columns + dnn_feature_columns\n    est = dnn_linear_combined.DNNLinearCombinedClassifier(\n        
linear_feature_columns=linear_feature_columns,\n        dnn_hidden_units=(2, 2),\n        dnn_feature_columns=dnn_feature_columns,\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    # TRAIN\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # EVALUATE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # PREDICT\n    predicted_proba = np.array([\n        x[prediction_keys.PredictionKeys.PROBABILITIES]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)\n\n    # EXPORT\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self, fc_impl):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    n_classes = 3\n    input_dimension = 2\n    batch_size = 10\n    data = np.linspace(\n        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)\n    x_data = data.reshape(batch_size, input_dimension)\n    y_data = self._as_label(np.reshape(data[:batch_size], (batch_size, 1)))\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': x_data},\n        y=y_data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': x_data}, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        
train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        n_classes=n_classes,\n        batch_size=batch_size,\n        fc_impl=fc_impl)\n\n  def test_pandas_input_fn(self, fc_impl):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n    input_dimension = 1\n    n_classes = 2\n    batch_size = 10\n    data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(self._as_label(data))\n    train_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        n_classes=n_classes,\n        batch_size=batch_size,\n        fc_impl=fc_impl)\n\n  def test_input_fn_from_parse_example(self, fc_impl):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    input_dimension = 2\n    n_classes = 3\n    batch_size = 10\n    data = np.linspace(\n        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, input_dimension)\n\n    serialized_examples = []\n    for datum in data:\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n                  'y':\n                      feature_pb2.Feature(\n                          
int64_list=feature_pb2.Int64List(\n                              value=self._as_label(datum[:1]))),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([input_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([1], tf.dtypes.int64),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = linear_testing_utils_v1.queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = linear_testing_utils_v1.queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = linear_testing_utils_v1.queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=input_dimension,\n        n_classes=n_classes,\n        batch_size=batch_size,\n        fc_impl=fc_impl)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\n@parameterized.parameters((feature_column,), (feature_column_v2,))\nclass DNNLinearCombinedTests(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _mock_optimizer(self, real_optimizer, var_name_prefix):\n    \"\"\"Verifies 
global_step is None and var_names start with given prefix.\"\"\"\n\n    def _minimize(loss, global_step=None, var_list=None):\n      self.assertIsNone(global_step)\n      trainable_vars = var_list or tf.compat.v1.get_collection(\n          tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n      var_names = [var.name for var in trainable_vars]\n      self.assertTrue(\n          all([name.startswith(var_name_prefix) for name in var_names]))\n      # var is used to check this op called by training.\n      with ops.name_scope(''):\n        var = tf.Variable(0., name=(var_name_prefix + '_called'))\n      with tf.control_dependencies([var.assign(100.)]):\n        return real_optimizer.minimize(loss, global_step, var_list)\n\n    optimizer_mock = tf.compat.v1.test.mock.NonCallableMagicMock(\n        spec=tf.compat.v1.train.Optimizer, wraps=real_optimizer)\n    optimizer_mock.minimize = tf.compat.v1.test.mock.MagicMock(wraps=_minimize)\n\n    return optimizer_mock\n\n  def test_train_op_calls_both_dnn_and_linear(self, fc_impl):\n    opt = tf.compat.v1.train.GradientDescentOptimizer(1.)\n    x_column = fc_impl.numeric_column('x')\n    input_fn = numpy_io.numpy_input_fn(\n        x={'x': np.array([[0.], [1.]])},\n        y=np.array([[0.], [1.]]),\n        batch_size=1,\n        shuffle=False)\n    est = dnn_linear_combined.DNNLinearCombinedClassifier(\n        linear_feature_columns=[x_column],\n        # verifies linear_optimizer is used only for linear part.\n        linear_optimizer=self._mock_optimizer(opt, 'linear'),\n        dnn_hidden_units=(2, 2),\n        dnn_feature_columns=[x_column],\n        # verifies dnn_optimizer is used only for dnn part.\n        dnn_optimizer=self._mock_optimizer(opt, 'dnn'),\n        model_dir=self._model_dir)\n    est.train(input_fn, steps=1)\n    # verifies train_op fires linear minimize op\n    self.assertEqual(100.,\n                     tf.train.load_variable(self._model_dir, 'linear_called'))\n    # verifies train_op fires dnn 
minimize op\n    self.assertEqual(100., tf.train.load_variable(self._model_dir,\n                                                  'dnn_called'))\n\n  def test_dnn_and_linear_logits_are_added(self, fc_impl):\n    with tf.Graph().as_default():\n      tf.Variable([[1.0]], name='linear/linear_model/x/weights')\n      tf.Variable([2.0], name='linear/linear_model/bias_weights')\n      tf.Variable([[3.0]], name='dnn/hiddenlayer_0/kernel')\n      tf.Variable([4.0], name='dnn/hiddenlayer_0/bias')\n      tf.Variable([[5.0]], name='dnn/logits/kernel')\n      tf.Variable([6.0], name='dnn/logits/bias')\n      tf.Variable(1, name='global_step', dtype=tf.dtypes.int64)\n      linear_testing_utils_v1.save_variables_to_ckpt(self._model_dir)\n\n    x_column = fc_impl.numeric_column('x')\n    est = dnn_linear_combined.DNNLinearCombinedRegressor(\n        linear_feature_columns=[x_column],\n        dnn_hidden_units=[1],\n        dnn_feature_columns=[x_column],\n        model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        x={'x': np.array([[10.]])}, batch_size=1, shuffle=False)\n    # linear logits = 10*1 + 2 = 12\n    # dnn logits = (10*3 + 4)*5 + 6 = 176\n    # logits = dnn + linear = 176 + 12 = 188\n    self.assertAllClose({\n        prediction_keys.PredictionKeys.PREDICTIONS: [188.],\n    }, next(est.predict(input_fn=input_fn)))\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\n@parameterized.parameters((feature_column,), (feature_column_v2,))\nclass DNNLinearCombinedWarmStartingTest(tf.test.TestCase):\n\n  def setUp(self):\n    # Create a directory to save our old checkpoint and vocabularies to.\n    self._ckpt_and_vocab_dir = tempfile.mkdtemp()\n\n    # Make a dummy input_fn.\n    def _input_fn():\n      features = {\n          'age': [[23.], [31.]],\n          'city': [['Palo Alto'], ['Mountain View']],\n      }\n      return features, [0, 1]\n\n    self._input_fn = _input_fn\n\n  def tearDown(self):\n    # Clean up checkpoint / vocab dir.\n    
tf.compat.v1.summary.FileWriterCache.clear()\n    shutil.rmtree(self._ckpt_and_vocab_dir)\n\n  def test_classifier_basic_warm_starting(self, fc_impl):\n    \"\"\"Tests correctness of DNNLinearCombinedClassifier default warm-start.\"\"\"\n    age = fc_impl.numeric_column('age')\n    city = fc_impl.embedding_column(\n        fc_impl.categorical_column_with_vocabulary_list(\n            'city', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n\n    # Create a DNNLinearCombinedClassifier and train to save a checkpoint.\n    dnn_lc_classifier = dnn_linear_combined.DNNLinearCombinedClassifier(\n        linear_feature_columns=[age],\n        dnn_feature_columns=[city],\n        dnn_hidden_units=[256, 128],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        linear_optimizer='SGD',\n        dnn_optimizer='SGD')\n    dnn_lc_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNLinearCombinedClassifier, warm-started from the first.\n    # Use a learning_rate = 0.0 optimizer to check values (use SGD so we don't\n    # have accumulator values that change).\n    warm_started_dnn_lc_classifier = (\n        dnn_linear_combined.DNNLinearCombinedClassifier(\n            linear_feature_columns=[age],\n            dnn_feature_columns=[city],\n            dnn_hidden_units=[256, 128],\n            n_classes=4,\n            linear_optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n                learning_rate=0.0),\n            dnn_optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n                learning_rate=0.0),\n            warm_start_from=dnn_lc_classifier.model_dir))\n\n    warm_started_dnn_lc_classifier.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_dnn_lc_classifier.get_variable_names():\n      self.assertAllClose(\n          dnn_lc_classifier.get_variable_value(variable_name),\n          
warm_started_dnn_lc_classifier.get_variable_value(variable_name))\n\n  def test_regressor_basic_warm_starting(self, fc_impl):\n    \"\"\"Tests correctness of DNNLinearCombinedRegressor default warm-start.\"\"\"\n    age = fc_impl.numeric_column('age')\n    city = fc_impl.embedding_column(\n        fc_impl.categorical_column_with_vocabulary_list(\n            'city', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n\n    # Create a DNNLinearCombinedRegressor and train to save a checkpoint.\n    dnn_lc_regressor = dnn_linear_combined.DNNLinearCombinedRegressor(\n        linear_feature_columns=[age],\n        dnn_feature_columns=[city],\n        dnn_hidden_units=[256, 128],\n        model_dir=self._ckpt_and_vocab_dir,\n        linear_optimizer='SGD',\n        dnn_optimizer='SGD')\n    dnn_lc_regressor.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNLinearCombinedRegressor, warm-started from the first.\n    # Use a learning_rate = 0.0 optimizer to check values (use SGD so we don't\n    # have accumulator values that change).\n    warm_started_dnn_lc_regressor = (\n        dnn_linear_combined.DNNLinearCombinedRegressor(\n            linear_feature_columns=[age],\n            dnn_feature_columns=[city],\n            dnn_hidden_units=[256, 128],\n            linear_optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n                learning_rate=0.0),\n            dnn_optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n                learning_rate=0.0),\n            warm_start_from=dnn_lc_regressor.model_dir))\n\n    warm_started_dnn_lc_regressor.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_dnn_lc_regressor.get_variable_names():\n      self.assertAllClose(\n          dnn_lc_regressor.get_variable_value(variable_name),\n          warm_started_dnn_lc_regressor.get_variable_value(variable_name))\n\n  def test_warm_starting_selective_variables(self, fc_impl):\n    \"\"\"Tests 
selecting variables to warm-start.\"\"\"\n    age = fc_impl.numeric_column('age')\n    city = fc_impl.embedding_column(\n        fc_impl.categorical_column_with_vocabulary_list(\n            'city', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n\n    # Create a DNNLinearCombinedClassifier and train to save a checkpoint.\n    dnn_lc_classifier = dnn_linear_combined.DNNLinearCombinedClassifier(\n        linear_feature_columns=[age],\n        dnn_feature_columns=[city],\n        dnn_hidden_units=[256, 128],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        linear_optimizer='SGD',\n        dnn_optimizer='SGD')\n    dnn_lc_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNLinearCombinedClassifier, warm-started from the first.\n    # Use a learning_rate = 0.0 optimizer to check values (use SGD so we don't\n    # have accumulator values that change).\n    warm_started_dnn_lc_classifier = (\n        dnn_linear_combined.DNNLinearCombinedClassifier(\n            linear_feature_columns=[age],\n            dnn_feature_columns=[city],\n            dnn_hidden_units=[256, 128],\n            n_classes=4,\n            linear_optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n                learning_rate=0.0),\n            dnn_optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n                learning_rate=0.0),\n            # The provided regular expression will only warm-start the deep\n            # portion of the model.\n            warm_start_from=estimator.WarmStartSettings(\n                ckpt_to_initialize_from=dnn_lc_classifier.model_dir,\n                vars_to_warm_start='.*(dnn).*')))\n\n    warm_started_dnn_lc_classifier.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_dnn_lc_classifier.get_variable_names():\n      if 'dnn' in variable_name:\n        self.assertAllClose(\n            dnn_lc_classifier.get_variable_value(variable_name),\n   
         warm_started_dnn_lc_classifier.get_variable_value(variable_name))\n      elif 'linear' in variable_name:\n        linear_values = warm_started_dnn_lc_classifier.get_variable_value(\n            variable_name)\n        # Since they're not warm-started, the linear weights will be\n        # zero-initialized.\n        self.assertAllClose(np.zeros_like(linear_values), linear_values)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/v1/dnn_test_fc_v1_v1.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for dnn.py with feature_column_v1.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport shutil\nimport tempfile\n\nfrom absl.testing import parameterized\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.core.example import example_pb2\nfrom tensorflow.core.example import feature_pb2\nfrom tensorflow.python.feature_column import feature_column\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.canned import dnn\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.canned.v1 import dnn_testing_utils_v1\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.inputs import pandas_io\n\ntry:\n  # pylint: disable=g-import-not-at-top\n  import pandas as pd\n  HAS_PANDAS = True\nexcept IOError:\n  # Pandas writes a temporary file during import. 
If it fails, don't use pandas.\n  HAS_PANDAS = False\nexcept ImportError:\n  HAS_PANDAS = False\n\n# Uses feature_column_v1 for testing.\nfeature_column.numeric_column = feature_column._numeric_column  # pylint: disable=protected-access\n\n\ndef _dnn_classifier_fn(*args, **kwargs):\n  return dnn.DNNClassifier(*args, **kwargs)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNModelFnTest(dnn_testing_utils_v1.BaseDNNModelFnTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNModelFnTest.__init__(\n        self, dnn._dnn_model_fn, fc_impl=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNLogitFnTest(dnn_testing_utils_v1.BaseDNNLogitFnTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNLogitFnTest.__init__(\n        self, dnn.dnn_logit_fn_builder, fc_impl=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNWarmStartingTest(dnn_testing_utils_v1.BaseDNNWarmStartingTest,\n                          tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNWarmStartingTest.__init__(\n        self, _dnn_classifier_fn, _dnn_regressor_fn, fc_impl=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNClassifierEvaluateTest(\n    dnn_testing_utils_v1.BaseDNNClassifierEvaluateTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNClassifierEvaluateTest.__init__(\n        self, _dnn_classifier_fn, fc_impl=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only 
symbols')\nclass DNNClassifierPredictTest(\n    dnn_testing_utils_v1.BaseDNNClassifierPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNClassifierPredictTest.__init__(\n        self, _dnn_classifier_fn, fc_impl=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNClassifierTrainTest(dnn_testing_utils_v1.BaseDNNClassifierTrainTest,\n                             tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNClassifierTrainTest.__init__(\n        self, _dnn_classifier_fn, fc_impl=feature_column)\n\n\ndef _dnn_regressor_fn(*args, **kwargs):\n  return dnn.DNNRegressor(*args, **kwargs)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNRegressorEvaluateTest(\n    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNRegressorPredictTest(dnn_testing_utils_v1.BaseDNNRegressorPredictTest,\n                              tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorPredictTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNRegressorTrainTest(dnn_testing_utils_v1.BaseDNNRegressorTrainTest,\n                            tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    
tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorTrainTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column)\n\n\ndef _queue_parsed_features(feature_map):\n  tensors_to_enqueue = []\n  keys = []\n  for key, tensor in six.iteritems(feature_map):\n    keys.append(key)\n    tensors_to_enqueue.append(tensor)\n  queue_dtypes = [x.dtype for x in tensors_to_enqueue]\n  input_queue = tf.queue.FIFOQueue(capacity=100, dtypes=queue_dtypes)\n  tf.compat.v1.train.queue_runner.add_queue_runner(\n      tf.compat.v1.train.queue_runner.QueueRunner(\n          input_queue, [input_queue.enqueue(tensors_to_enqueue)]))\n  dequeued_tensors = input_queue.dequeue()\n  return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNRegressorIntegrationTest(tf.test.TestCase, parameterized.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, label_dimension, batch_size):\n    feature_columns = [\n        feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n\n    est = dnn.DNNRegressor(\n        hidden_units=(2, 2),\n        feature_columns=feature_columns,\n        label_dimension=label_dimension,\n        model_dir=self._model_dir)\n\n    # TRAIN\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # EVALUATE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # PREDICT\n    predictions = np.array([\n        x[prediction_keys.PredictionKeys.PREDICTIONS]\n        for x in est.predict(predict_input_fn)\n    ])\n    
self.assertAllEqual((batch_size, label_dimension), predictions.shape)\n\n    # EXPORT\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size)\n\n  def test_pandas_input_fn(self):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n    label_dimension = 1\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size, dtype=np.float32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(data)\n    train_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = 
pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size)\n\n  def test_input_fn_from_parse_example(self):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n\n    serialized_examples = []\n    for datum in data:\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n                  'y':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([label_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([label_dimension], tf.dtypes.float32),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = _queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = _queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n     
 feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = _queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNClassifierIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _as_label(self, data_in_float):\n    return np.rint(data_in_float).astype(np.int64)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, n_classes, batch_size):\n    feature_columns = [\n        feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n\n    est = dnn.DNNClassifier(\n        hidden_units=(2, 2),\n        feature_columns=feature_columns,\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    # TRAIN\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # EVALUATE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # PREDICT\n    predicted_proba = np.array([\n        x[prediction_keys.PredictionKeys.PROBABILITIES]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)\n\n    # EXPORT\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    n_classes = 3\n    input_dimension = 2\n    batch_size = 10\n    data = np.linspace(\n        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)\n    x_data = data.reshape(batch_size, input_dimension)\n    y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': x_data},\n        y=y_data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': x_data}, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        n_classes=n_classes,\n        batch_size=batch_size)\n\n  def test_pandas_input_fn(self):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n    input_dimension = 1\n    n_classes = 3\n    batch_size = 10\n    data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(self._as_label(data))\n    train_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n      
  x=x, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        n_classes=n_classes,\n        batch_size=batch_size)\n\n  def test_input_fn_from_parse_example(self):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    input_dimension = 2\n    n_classes = 3\n    batch_size = 10\n    data = np.linspace(\n        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, input_dimension)\n\n    serialized_examples = []\n    for datum in data:\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n                  'y':\n                      feature_pb2.Feature(\n                          int64_list=feature_pb2.Int64List(\n                              value=self._as_label(datum[:1]))),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([input_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([1], tf.dtypes.int64),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = _queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = _queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    
def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = _queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=input_dimension,\n        n_classes=n_classes,\n        batch_size=batch_size)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/v1/dnn_test_fc_v2_v1.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for dnn.py with feature_column_v2.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport shutil\nimport tempfile\n\nfrom absl.testing import parameterized\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.core.example import example_pb2\nfrom tensorflow.core.example import feature_pb2\nfrom tensorflow.python.feature_column import feature_column_v2\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.canned import dnn\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.canned.v1 import dnn_testing_utils_v1\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.inputs import pandas_io\n\ntry:\n  # pylint: disable=g-import-not-at-top\n  import pandas as pd\n  HAS_PANDAS = True\nexcept IOError:\n  # Pandas writes a temporary file during import. 
If it fails, don't use pandas.\n  HAS_PANDAS = False\nexcept ImportError:\n  HAS_PANDAS = False\n\n\ndef _dnn_classifier_fn(*args, **kwargs):\n  return dnn.DNNClassifier(*args, **kwargs)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNModelFnV2Test(dnn_testing_utils_v1.BaseDNNModelFnTest,\n                       tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNModelFnTest.__init__(\n        self, dnn._dnn_model_fn, fc_impl=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNLogitFnV2Test(dnn_testing_utils_v1.BaseDNNLogitFnTest,\n                       tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNLogitFnTest.__init__(\n        self, dnn.dnn_logit_fn_builder, fc_impl=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNWarmStartingV2Test(dnn_testing_utils_v1.BaseDNNWarmStartingTest,\n                            tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNWarmStartingTest.__init__(\n        self, _dnn_classifier_fn, _dnn_regressor_fn, fc_impl=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNClassifierEvaluateV2Test(\n    dnn_testing_utils_v1.BaseDNNClassifierEvaluateTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNClassifierEvaluateTest.__init__(\n        self, _dnn_classifier_fn, fc_impl=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNClassifierPredictV2Test(\n    
dnn_testing_utils_v1.BaseDNNClassifierPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNClassifierPredictTest.__init__(\n        self, _dnn_classifier_fn, fc_impl=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNClassifierTrainV2Test(dnn_testing_utils_v1.BaseDNNClassifierTrainTest,\n                               tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNClassifierTrainTest.__init__(\n        self, _dnn_classifier_fn, fc_impl=feature_column_v2)\n\n\ndef _dnn_regressor_fn(*args, **kwargs):\n  return dnn.DNNRegressor(*args, **kwargs)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNRegressorEvaluateV2Test(\n    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorEvaluateTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNRegressorPredictV2Test(\n    dnn_testing_utils_v1.BaseDNNRegressorPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    dnn_testing_utils_v1.BaseDNNRegressorPredictTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNRegressorTrainV2Test(dnn_testing_utils_v1.BaseDNNRegressorTrainTest,\n                              tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    
dnn_testing_utils_v1.BaseDNNRegressorTrainTest.__init__(\n        self, _dnn_regressor_fn, fc_impl=feature_column_v2)\n\n\ndef _queue_parsed_features(feature_map):\n  tensors_to_enqueue = []\n  keys = []\n  for key, tensor in six.iteritems(feature_map):\n    keys.append(key)\n    tensors_to_enqueue.append(tensor)\n  queue_dtypes = [x.dtype for x in tensors_to_enqueue]\n  input_queue = tf.queue.FIFOQueue(capacity=100, dtypes=queue_dtypes)\n  tf.compat.v1.train.queue_runner.add_queue_runner(\n      tf.compat.v1.train.queue_runner.QueueRunner(\n          input_queue, [input_queue.enqueue(tensors_to_enqueue)]))\n  dequeued_tensors = input_queue.dequeue()\n  return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNRegressorIntegrationTest(tf.test.TestCase, parameterized.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, label_dimension, batch_size):\n    feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n\n    est = dnn.DNNRegressor(\n        hidden_units=(2, 2),\n        feature_columns=feature_columns,\n        label_dimension=label_dimension,\n        model_dir=self._model_dir)\n\n    # TRAIN\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # EVALUATE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # PREDICT\n    predictions = np.array([\n        x[prediction_keys.PredictionKeys.PREDICTIONS]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, label_dimension), 
predictions.shape)\n\n    # EXPORT\n    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size)\n\n  def test_pandas_input_fn(self):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n    label_dimension = 1\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size, dtype=np.float32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(data)\n    train_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, 
shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size)\n\n  def test_input_fn_from_parse_example(self):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n\n    serialized_examples = []\n    for datum in data:\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n                  'y':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([label_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([label_dimension], tf.dtypes.float32),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = _queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = _queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          
tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = _queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass DNNClassifierIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _as_label(self, data_in_float):\n    return np.rint(data_in_float).astype(np.int64)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, n_classes, batch_size):\n    feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n\n    est = dnn.DNNClassifier(\n        hidden_units=(2, 2),\n        feature_columns=feature_columns,\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    # TRAIN\n    num_steps = 10\n    est.train(train_input_fn, steps=num_steps)\n\n    # EVALUATE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # PREDICT\n    predicted_proba = np.array([\n        x[prediction_keys.PredictionKeys.PROBABILITIES]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)\n\n    # EXPORT\n    feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        
feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    n_classes = 3\n    input_dimension = 2\n    batch_size = 10\n    data = np.linspace(\n        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)\n    x_data = data.reshape(batch_size, input_dimension)\n    y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': x_data},\n        y=y_data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': x_data}, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        n_classes=n_classes,\n        batch_size=batch_size)\n\n  def test_pandas_input_fn(self):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n    input_dimension = 1\n    n_classes = 3\n    batch_size = 10\n    data = np.linspace(0., n_classes - 1., batch_size, dtype=np.float32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(self._as_label(data))\n    train_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n      
  train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        n_classes=n_classes,\n        batch_size=batch_size)\n\n  def test_input_fn_from_parse_example(self):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    input_dimension = 2\n    n_classes = 3\n    batch_size = 10\n    data = np.linspace(\n        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, input_dimension)\n\n    serialized_examples = []\n    for datum in data:\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n                  'y':\n                      feature_pb2.Feature(\n                          int64_list=feature_pb2.Int64List(\n                              value=self._as_label(datum[:1]))),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([input_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([1], tf.dtypes.int64),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = _queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = _queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n       
   tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = _queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=input_dimension,\n        n_classes=n_classes,\n        batch_size=batch_size)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/v1/dnn_testing_utils_v1.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Utils to be used in testing DNN estimators.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport shutil\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.feature_column import feature_column\nfrom tensorflow.python.framework import ops\nfrom tensorflow_estimator.python.estimator import estimator\nfrom tensorflow_estimator.python.estimator import model_fn\nfrom tensorflow_estimator.python.estimator.canned import head as head_lib\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n# pylint rules which are disabled by default for test files.\n# pylint: disable=invalid-name,protected-access,missing-docstring\n\n# Names of variables created by model.\nLEARNING_RATE_NAME = 'dnn/regression_head/dnn/learning_rate'\nHIDDEN_WEIGHTS_NAME_PATTERN = 'dnn/hiddenlayer_%d/kernel'\nHIDDEN_BIASES_NAME_PATTERN = 'dnn/hiddenlayer_%d/bias'\nBATCH_NORM_BETA_NAME_PATTERN = 
'dnn/hiddenlayer_%d/batchnorm_%d/beta'\nBATCH_NORM_GAMMA_NAME_PATTERN = 'dnn/hiddenlayer_%d/batchnorm_%d/gamma'\nBATCH_NORM_MEAN_NAME_PATTERN = 'dnn/hiddenlayer_%d/batchnorm_%d/moving_mean'\nBATCH_NORM_VARIANCE_NAME_PATTERN = (\n    'dnn/hiddenlayer_%d/batchnorm_%d/moving_variance')\nLOGITS_WEIGHTS_NAME = 'dnn/logits/kernel'\nLOGITS_BIASES_NAME = 'dnn/logits/bias'\nOCCUPATION_EMBEDDING_NAME = ('dnn/input_from_feature_columns/input_layer/'\n                             'occupation_embedding/embedding_weights')\nCITY_EMBEDDING_NAME = ('dnn/input_from_feature_columns/input_layer/'\n                       'city_embedding/embedding_weights')\n\n# This is so that we can easily switch between feature_column and\n# feature_column_v2 for testing.\nfeature_column.numeric_column = feature_column._numeric_column\nfeature_column.categorical_column_with_hash_bucket = feature_column._categorical_column_with_hash_bucket  # pylint: disable=line-too-long\nfeature_column.categorical_column_with_vocabulary_list = feature_column._categorical_column_with_vocabulary_list  # pylint: disable=line-too-long\nfeature_column.categorical_column_with_vocabulary_file = feature_column._categorical_column_with_vocabulary_file  # pylint: disable=line-too-long\nfeature_column.embedding_column = feature_column._embedding_column\n\n\ndef assert_close(expected, actual, rtol=1e-04, message='', name='assert_close'):\n  with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope:\n    expected = ops.convert_to_tensor(expected, name='expected')\n    actual = ops.convert_to_tensor(actual, name='actual')\n    rdiff = tf.math.abs((expected - actual) / expected, 'diff')\n    rtol = ops.convert_to_tensor(rtol, name='rtol')\n    return tf.compat.v1.debugging.assert_less(\n        rdiff,\n        rtol,\n        data=(message, 'Condition expected =~ actual did not hold element-wise:'\n              'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff,\n              'rtol = ', rtol,),\n  
      summarize=expected.get_shape().num_elements(),\n        name=scope)\n\n\ndef create_checkpoint(weights_and_biases,\n                      global_step,\n                      model_dir,\n                      batch_norm_vars=None):\n  \"\"\"Create checkpoint file with provided model weights.\n\n  Args:\n    weights_and_biases: Iterable of tuples of weight and bias values.\n    global_step: Initial global step to save in checkpoint.\n    model_dir: Directory into which checkpoint is saved.\n    batch_norm_vars: Variables used for batch normalization.\n  \"\"\"\n  weights, biases = zip(*weights_and_biases)\n  if batch_norm_vars:\n    assert len(batch_norm_vars) == len(weights_and_biases) - 1\n    (bn_betas, bn_gammas, bn_means, bn_variances) = zip(*batch_norm_vars)\n  model_weights = {}\n\n  # Hidden layer weights.\n  for i in range(0, len(weights) - 1):\n    model_weights[HIDDEN_WEIGHTS_NAME_PATTERN % i] = weights[i]\n    model_weights[HIDDEN_BIASES_NAME_PATTERN % i] = biases[i]\n    if batch_norm_vars:\n      model_weights[BATCH_NORM_BETA_NAME_PATTERN % (i, i)] = bn_betas[i]\n      model_weights[BATCH_NORM_GAMMA_NAME_PATTERN % (i, i)] = bn_gammas[i]\n      model_weights[BATCH_NORM_MEAN_NAME_PATTERN % (i, i)] = bn_means[i]\n      model_weights[BATCH_NORM_VARIANCE_NAME_PATTERN % (i, i)] = bn_variances[i]\n\n  # Output layer weights.\n  model_weights[LOGITS_WEIGHTS_NAME] = weights[-1]\n  model_weights[LOGITS_BIASES_NAME] = biases[-1]\n\n  with tf.Graph().as_default():\n    # Create model variables.\n    for k, v in six.iteritems(model_weights):\n      tf.Variable(v, name=k, dtype=tf.dtypes.float32)\n\n    # Create non-model variables.\n    global_step_var = tf.compat.v1.train.create_global_step()\n\n    # Initialize vars and save checkpoint.\n    with tf.compat.v1.Session() as sess:\n      tf.compat.v1.initializers.global_variables().run()\n      global_step_var.assign(global_step).eval()\n      tf.compat.v1.train.Saver().save(sess,\n                              
        os.path.join(model_dir, 'model.ckpt'))\n\n\ndef mock_head(testcase, hidden_units, logits_dimension, expected_logits):\n  \"\"\"Returns a mock head that validates logits values and variable names.\"\"\"\n  hidden_weights_names = [(HIDDEN_WEIGHTS_NAME_PATTERN + '/part_0:0') % i\n                          for i in range(len(hidden_units))]\n  hidden_biases_names = [(HIDDEN_BIASES_NAME_PATTERN + '/part_0:0') % i\n                         for i in range(len(hidden_units))]\n  expected_var_names = (\n      hidden_weights_names + hidden_biases_names +\n      [LOGITS_WEIGHTS_NAME + '/part_0:0', LOGITS_BIASES_NAME + '/part_0:0'])\n\n  def _create_tpu_estimator_spec(features,\n                                 mode,\n                                 logits,\n                                 labels,\n                                 train_op_fn=None,\n                                 optimizer=None):\n    del features, labels  # Not used.\n    trainable_vars = tf.compat.v1.get_collection(\n        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n    testcase.assertItemsEqual(expected_var_names,\n                              [var.name for var in trainable_vars])\n    loss = tf.constant(1.)\n    assert_logits = assert_close(\n        expected_logits, logits, message='Failed for mode={}. 
'.format(mode))\n    with tf.control_dependencies([assert_logits]):\n      if mode == ModeKeys.TRAIN:\n        if train_op_fn is not None:\n          train_op = train_op_fn(loss)\n        elif optimizer is not None:\n          train_op = optimizer.minimize(loss, global_step=None)\n        return model_fn._TPUEstimatorSpec(\n            mode=mode, loss=loss, train_op=train_op)\n      elif mode == ModeKeys.EVAL:\n        return model_fn._TPUEstimatorSpec(mode=mode, loss=tf.identity(loss))\n      elif mode == ModeKeys.PREDICT:\n        return model_fn._TPUEstimatorSpec(\n            mode=mode, predictions={'logits': tf.identity(logits)})\n      else:\n        testcase.fail('Invalid mode: {}'.format(mode))\n\n  def _create_estimator_spec(features,\n                             mode,\n                             logits,\n                             labels,\n                             train_op_fn=None,\n                             optimizer=None):\n    tpu_spec = _create_tpu_estimator_spec(features, mode, logits, labels,\n                                          train_op_fn, optimizer)\n    return tpu_spec.as_estimator_spec()\n\n  head = tf.compat.v1.test.mock.NonCallableMagicMock(spec=head_lib._Head)\n  head.logits_dimension = logits_dimension\n  head._create_tpu_estimator_spec = tf.compat.v1.test.mock.MagicMock(\n      wraps=_create_tpu_estimator_spec)\n  head.create_estimator_spec = tf.compat.v1.test.mock.MagicMock(\n      wraps=_create_estimator_spec)\n\n  return head\n\n\ndef mock_optimizer(testcase, hidden_units, expected_loss=None):\n  \"\"\"Creates a mock optimizer to test the train method.\n\n  Args:\n    testcase: A TestCase instance.\n    hidden_units: Iterable of integer sizes for the hidden layers.\n    expected_loss: If given, will assert the loss value.\n\n  Returns:\n    A mock Optimizer.\n  \"\"\"\n  hidden_weights_names = [(HIDDEN_WEIGHTS_NAME_PATTERN + '/part_0:0') % i\n                          for i in range(len(hidden_units))]\n  
hidden_biases_names = [(HIDDEN_BIASES_NAME_PATTERN + '/part_0:0') % i\n                         for i in range(len(hidden_units))]\n  expected_var_names = (\n      hidden_weights_names + hidden_biases_names +\n      [LOGITS_WEIGHTS_NAME + '/part_0:0', LOGITS_BIASES_NAME + '/part_0:0'])\n\n  def _minimize(loss, global_step=None, var_list=None):\n    \"\"\"Mock of optimizer.minimize.\"\"\"\n    trainable_vars = var_list or tf.compat.v1.get_collection(\n        tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n    testcase.assertItemsEqual(expected_var_names,\n                              [var.name for var in trainable_vars])\n\n    # Verify loss. We can't check the value directly, so we add an assert op.\n    testcase.assertEquals(0, loss.shape.ndims)\n    if expected_loss is None:\n      if global_step is not None:\n        return tf.compat.v1.assign_add(global_step, 1).op\n      return tf.no_op()\n    assert_loss = assert_close(\n        tf.cast(expected_loss, name='expected', dtype=tf.dtypes.float32),\n        loss,\n        name='assert_loss')\n    with tf.control_dependencies((assert_loss,)):\n      if global_step is not None:\n        return tf.compat.v1.assign_add(global_step, 1).op\n      return tf.no_op()\n\n  optimizer_mock = tf.compat.v1.test.mock.NonCallableMagicMock(\n      spec=tf.compat.v1.train.Optimizer,\n      wraps=tf.compat.v1.train.Optimizer(\n          use_locking=False, name='my_optimizer'))\n  optimizer_mock.minimize = tf.compat.v1.test.mock.MagicMock(wraps=_minimize)\n\n  return optimizer_mock\n\n\nclass BaseDNNModelFnTest(object):\n  \"\"\"Tests that _dnn_model_fn passes expected logits to mock head.\"\"\"\n\n  def __init__(self, dnn_model_fn, fc_impl=feature_column):\n    self._dnn_model_fn = dnn_model_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def 
_test_logits(self, mode, hidden_units, logits_dimension, inputs,\n                   expected_logits):\n    \"\"\"Tests that the expected logits are passed to mock head.\"\"\"\n    with tf.Graph().as_default():\n      tf.compat.v1.train.create_global_step()\n      head = mock_head(\n          self,\n          hidden_units=hidden_units,\n          logits_dimension=logits_dimension,\n          expected_logits=expected_logits)\n      estimator_spec = self._dnn_model_fn(\n          features={'age': tf.constant(inputs)},\n          labels=tf.constant([[1]]),\n          mode=mode,\n          head=head,\n          hidden_units=hidden_units,\n          feature_columns=[\n              self._fc_impl.numeric_column(\n                  'age', shape=np.array(inputs).shape[1:])\n          ],\n          optimizer=mock_optimizer(self, hidden_units))\n      with tf.compat.v1.train.MonitoredTrainingSession(\n          checkpoint_dir=self._model_dir) as sess:\n        if mode == ModeKeys.TRAIN:\n          sess.run(estimator_spec.train_op)\n        elif mode == ModeKeys.EVAL:\n          sess.run(estimator_spec.loss)\n        elif mode == ModeKeys.PREDICT:\n          sess.run(estimator_spec.predictions)\n        else:\n          self.fail('Invalid mode: {}'.format(mode))\n\n  def test_one_dim_logits(self):\n    \"\"\"Tests one-dimensional logits.\n\n    input_layer = [[10]]\n    hidden_layer_0 = [[relu(0.6*10 +0.1), relu(0.5*10 -0.1)]] = [[6.1, 4.9]]\n    hidden_layer_1 = [[relu(1*6.1 -0.8*4.9 +0.2), relu(0.8*6.1 -1*4.9 -0.1)]]\n                   = [[relu(2.38), relu(-0.12)]] = [[2.38, 0]]\n    logits = [[-1*2.38 +1*0 +0.3]] = [[-2.08]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), base_global_step, self._model_dir)\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          
hidden_units=(2, 2),\n          logits_dimension=1,\n          inputs=[[10.]],\n          expected_logits=[[-2.08]])\n\n  def test_multi_dim_logits(self):\n    \"\"\"Tests multi-dimensional logits.\n\n    input_layer = [[10]]\n    hidden_layer_0 = [[relu(0.6*10 +0.1), relu(0.5*10 -0.1)]] = [[6.1, 4.9]]\n    hidden_layer_1 = [[relu(1*6.1 -0.8*4.9 +0.2), relu(0.8*6.1 -1*4.9 -0.1)]]\n                   = [[relu(2.38), relu(-0.12)]] = [[2.38, 0]]\n    logits = [[-1*2.38 +0.3, 1*2.38 -0.3, 0.5*2.38]]\n           = [[-2.08, 2.08, 1.19]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=3,\n          inputs=[[10.]],\n          expected_logits=[[-2.08, 2.08, 1.19]])\n\n  def test_multi_example_multi_dim_logits(self):\n    \"\"\"Tests multiple examples and multi-dimensional logits.\n\n    input_layer = [[10], [5]]\n    hidden_layer_0 = [[relu(0.6*10 +0.1), relu(0.5*10 -0.1)],\n                      [relu(0.6*5 +0.1), relu(0.5*5 -0.1)]]\n                   = [[6.1, 4.9], [3.1, 2.4]]\n    hidden_layer_1 = [[relu(1*6.1 -0.8*4.9 +0.2), relu(0.8*6.1 -1*4.9 -0.1)],\n                      [relu(1*3.1 -0.8*2.4 +0.2), relu(0.8*3.1 -1*2.4 -0.1)]]\n                   = [[2.38, 0], [1.38, 0]]\n    logits = [[-1*2.38 +0.3, 1*2.38 -0.3, 0.5*2.38],\n              [-1*1.38 +0.3, 1*1.38 -0.3, 0.5*1.38]]\n           = [[-2.08, 2.08, 1.19], [-1.08, 1.08, 0.69]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n\n    for mode in 
[ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=3,\n          inputs=[[10.], [5.]],\n          expected_logits=[[-2.08, 2.08, 1.19], [-1.08, 1.08, .69]])\n\n  def test_multi_dim_input_one_dim_logits(self):\n    \"\"\"Tests multi-dimensional inputs and one-dimensional logits.\n\n    input_layer = [[10, 8]]\n    hidden_layer_0 = [[relu(0.6*10 -0.6*8 +0.1), relu(0.5*10 -0.5*8 -0.1)]]\n                   = [[1.3, 0.9]]\n    hidden_layer_1 = [[relu(1*1.3 -0.8*0.9 + 0.2), relu(0.8*1.3 -1*0.9 -0.2)]]\n                   = [[0.78, relu(-0.06)]] = [[0.78, 0]]\n    logits = [[-1*0.78 +1*0 +0.3]] = [[-0.48]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), base_global_step, self._model_dir)\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=1,\n          inputs=[[10., 8.]],\n          expected_logits=[[-0.48]])\n\n  def test_multi_dim_input_multi_dim_logits(self):\n    \"\"\"Tests multi-dimensional inputs and multi-dimensional logits.\n\n    input_layer = [[10, 8]]\n    hidden_layer_0 = [[relu(0.6*10 -0.6*8 +0.1), relu(0.5*10 -0.5*8 -0.1)]]\n                   = [[1.3, 0.9]]\n    hidden_layer_1 = [[relu(1*1.3 -0.8*0.9 + 0.2), relu(0.8*1.3 -1*0.9 -0.2)]]\n                   = [[0.78, relu(-0.06)]] = [[0.78, 0]]\n    logits = [[-1*0.78 + 0.3, 1*0.78 -0.3, 0.5*0.78]] = [[-0.48, 0.48, 0.39]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, 
ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=3,\n          inputs=[[10., 8.]],\n          expected_logits=[[-0.48, 0.48, 0.39]])\n\n  def test_multi_feature_column_multi_dim_logits(self):\n    \"\"\"Tests multiple feature columns and multi-dimensional logits.\n\n    All numbers are the same as test_multi_dim_input_multi_dim_logits. The only\n    difference is that the input consists of two 1D feature columns, instead of\n    one 2D feature column.\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n    hidden_units = (2, 2)\n    logits_dimension = 3\n    inputs = ([[10.]], [[8.]])\n    expected_logits = [[-0.48, 0.48, 0.39]]\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      with tf.Graph().as_default():\n        tf.compat.v1.train.create_global_step()\n        head = mock_head(\n            self,\n            hidden_units=hidden_units,\n            logits_dimension=logits_dimension,\n            expected_logits=expected_logits)\n        estimator_spec = self._dnn_model_fn(\n            features={\n                'age': tf.constant(inputs[0]),\n                'height': tf.constant(inputs[1])\n            },\n            labels=tf.constant([[1]]),\n            mode=mode,\n            head=head,\n            hidden_units=hidden_units,\n            feature_columns=[\n                self._fc_impl.numeric_column('age'),\n                self._fc_impl.numeric_column('height')\n            ],\n            optimizer=mock_optimizer(self, hidden_units))\n        with tf.compat.v1.train.MonitoredTrainingSession(\n            checkpoint_dir=self._model_dir) as sess:\n          if mode == ModeKeys.TRAIN:\n            sess.run(estimator_spec.train_op)\n          elif mode 
== ModeKeys.EVAL:\n            sess.run(estimator_spec.loss)\n          elif mode == ModeKeys.PREDICT:\n            sess.run(estimator_spec.predictions)\n          else:\n            self.fail('Invalid mode: {}'.format(mode))\n\n  def test_multi_feature_column_mix_multi_dim_logits(self):\n    \"\"\"Tests multiple feature columns and multi-dimensional logits.\n\n    All numbers are the same as test_multi_dim_input_multi_dim_logits. The only\n    difference is that the input consists of two 1D feature columns, instead of\n    one 2D feature column.\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n    hidden_units = (2, 2)\n    logits_dimension = 3\n    inputs = ([[10.]], [[8.]])\n    expected_logits = [[-0.48, 0.48, 0.39]]\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      with tf.Graph().as_default():\n        tf.compat.v1.train.create_global_step()\n        head = mock_head(\n            self,\n            hidden_units=hidden_units,\n            logits_dimension=logits_dimension,\n            expected_logits=expected_logits)\n        estimator_spec = self._dnn_model_fn(\n            features={\n                'age': tf.constant(inputs[0]),\n                'height': tf.constant(inputs[1])\n            },\n            labels=tf.constant([[1]]),\n            mode=mode,\n            head=head,\n            hidden_units=hidden_units,\n            feature_columns=[\n                feature_column.numeric_column('age'),\n                tf.feature_column.numeric_column('height')\n            ],\n            optimizer=mock_optimizer(self, hidden_units))\n        with tf.compat.v1.train.MonitoredTrainingSession(\n            checkpoint_dir=self._model_dir) as sess:\n          if mode == ModeKeys.TRAIN:\n            
sess.run(estimator_spec.train_op)\n          elif mode == ModeKeys.EVAL:\n            sess.run(estimator_spec.loss)\n          elif mode == ModeKeys.PREDICT:\n            sess.run(estimator_spec.predictions)\n          else:\n            self.fail('Invalid mode: {}'.format(mode))\n\n  def test_features_tensor_raises_value_error(self):\n    \"\"\"Tests that passing a Tensor for features raises a ValueError.\"\"\"\n    hidden_units = (2, 2)\n    logits_dimension = 3\n    inputs = ([[10.]], [[8.]])\n    expected_logits = [[0, 0, 0]]\n\n    with tf.Graph().as_default():\n      tf.compat.v1.train.create_global_step()\n      head = mock_head(\n          self,\n          hidden_units=hidden_units,\n          logits_dimension=logits_dimension,\n          expected_logits=expected_logits)\n      with self.assertRaisesRegexp(ValueError, 'features should be a dict'):\n        self._dnn_model_fn(\n            features=tf.constant(inputs),\n            labels=tf.constant([[1]]),\n            mode=ModeKeys.TRAIN,\n            head=head,\n            hidden_units=hidden_units,\n            feature_columns=[\n                self._fc_impl.numeric_column(\n                    'age', shape=np.array(inputs).shape[1:])\n            ],\n            optimizer=mock_optimizer(self, hidden_units))\n\n\nclass BaseDNNLogitFnTest(object):\n  \"\"\"Tests correctness of logits calculated from _dnn_logit_fn_builder.\"\"\"\n\n  def __init__(self, dnn_logit_fn_builder, fc_impl=feature_column):\n    self._dnn_logit_fn_builder = dnn_logit_fn_builder\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_logits(self,\n                   mode,\n                   hidden_units,\n                   logits_dimension,\n                   inputs,\n                   expected_logits,\n                   
batch_norm=False):\n    \"\"\"Tests that the expected logits are calculated.\"\"\"\n    with tf.Graph().as_default():\n      # Global step needed for MonitoredSession, which is in turn used to\n      # explicitly set variable weights through a checkpoint.\n      tf.compat.v1.train.create_global_step()\n      # Use a variable scope here with 'dnn', emulating the dnn model_fn, so\n      # the checkpoint naming is shared.\n      with tf.compat.v1.variable_scope('dnn'):\n        input_layer_partitioner = (\n            tf.compat.v1.min_max_variable_partitioner(\n                max_partitions=0, min_slice_size=64 << 20))\n        logit_fn = self._dnn_logit_fn_builder(\n            units=logits_dimension,\n            hidden_units=hidden_units,\n            feature_columns=[\n                self._fc_impl.numeric_column(\n                    'age', shape=np.array(inputs).shape[1:])\n            ],\n            activation_fn=tf.nn.relu,\n            dropout=None,\n            input_layer_partitioner=input_layer_partitioner,\n            batch_norm=batch_norm)\n        logits = logit_fn(features={'age': tf.constant(inputs)}, mode=mode)\n        with tf.compat.v1.train.MonitoredTrainingSession(\n            checkpoint_dir=self._model_dir) as sess:\n          self.assertAllClose(expected_logits, sess.run(logits))\n\n  def test_one_dim_logits(self):\n    \"\"\"Tests one-dimensional logits.\n\n    input_layer = [[10]]\n    hidden_layer_0 = [[relu(0.6*10 +0.1), relu(0.5*10 -0.1)]] = [[6.1, 4.9]]\n    hidden_layer_1 = [[relu(1*6.1 -0.8*4.9 +0.2), relu(0.8*6.1 -1*4.9 -0.1)]]\n                   = [[relu(2.38), relu(-0.12)]] = [[2.38, 0]]\n    logits = [[-1*2.38 +1*0 +0.3]] = [[-2.08]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), base_global_step, self._model_dir)\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      
self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=1,\n          inputs=[[10.]],\n          expected_logits=[[-2.08]])\n\n  def test_one_dim_logits_with_batch_norm(self):\n    \"\"\"Tests one-dimensional logits.\n\n    input_layer = [[10]]\n    hidden_layer_0 = [[relu(0.6*10 +1), relu(0.5*10 -1)]] = [[7, 4]]\n    hidden_layer_0 = [[relu(0.6*20 +1), relu(0.5*20 -1)]] = [[13, 9]]\n\n    batch_norm_0, training (epsilon = 0.001):\n      mean1 = 1/2*(7+13) = 10,\n      variance1 = 1/2*(3^2+3^2) = 9\n      x11 = (7-10)/sqrt(9+0.001) = -0.999944449,\n      x21 = (13-10)/sqrt(9+0.001) = 0.999944449,\n\n      mean2 = 1/2*(4+9) = 6.5,\n      variance2 = 1/2*(2.5^2+.2.5^2) = 6.25\n      x12 = (4-6.5)/sqrt(6.25+0.001) = -0.99992001,\n      x22 = (9-6.5)/sqrt(6.25+0.001) = 0.99992001,\n\n    logits = [[-1*(-0.999944449) + 2*(-0.99992001) + 0.3],\n              [-1*0.999944449 + 2*0.99992001 + 0.3]]\n           = [[-0.699895571],[1.299895571]]\n\n    batch_norm_0, not training (epsilon = 0.001):\n      moving_mean1 = 0, moving_variance1 = 1\n      x11 = (7-0)/sqrt(1+0.001) = 6.996502623,\n      x21 = (13-0)/sqrt(1+0.001) = 12.993504871,\n      moving_mean2 = 0, moving_variance2 = 1\n      x12 = (4-0)/sqrt(1+0.001) = 3.998001499,\n      x22 = (9-0)/sqrt(1+0.001) = 8.995503372,\n\n    logits = [[-1*6.996502623 + 2*3.998001499 + 0.3],\n              [-1*12.993504871 + 2*8.995503372 + 0.3]]\n           = [[1.299500375],[5.297501873]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint(\n        (\n            ([[.6, .5]], [1., -1.]),\n            ([[-1.], [2.]], [.3]),\n        ),\n        base_global_step,\n        self._model_dir,\n        batch_norm_vars=(\n            [\n                [0, 0],  # beta.\n                [1, 1],  # gamma.\n                [0, 0],  # moving mean.\n                [1, 1],  # moving variance.\n            ],))\n    self._test_logits(\n        ModeKeys.TRAIN,\n        hidden_units=[2],\n        
logits_dimension=1,\n        inputs=[[10.], [20.]],\n        expected_logits=[[-0.699895571], [1.299895571]],\n        batch_norm=True)\n    for mode in [ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=[2],\n          logits_dimension=1,\n          inputs=[[10.], [20.]],\n          expected_logits=[[1.299500375], [5.297501873]],\n          batch_norm=True)\n\n  def test_multi_dim_logits(self):\n    \"\"\"Tests multi-dimensional logits.\n\n    input_layer = [[10]]\n    hidden_layer_0 = [[relu(0.6*10 +0.1), relu(0.5*10 -0.1)]] = [[6.1, 4.9]]\n    hidden_layer_1 = [[relu(1*6.1 -0.8*4.9 +0.2), relu(0.8*6.1 -1*4.9 -0.1)]]\n                   = [[relu(2.38), relu(-0.12)]] = [[2.38, 0]]\n    logits = [[-1*2.38 +0.3, 1*2.38 -0.3, 0.5*2.38]]\n           = [[-2.08, 2.08, 1.19]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=3,\n          inputs=[[10.]],\n          expected_logits=[[-2.08, 2.08, 1.19]])\n\n  def test_multi_example_multi_dim_logits(self):\n    \"\"\"Tests multiple examples and multi-dimensional logits.\n\n    input_layer = [[10], [5]]\n    hidden_layer_0 = [[relu(0.6*10 +0.1), relu(0.5*10 -0.1)],\n                      [relu(0.6*5 +0.1), relu(0.5*5 -0.1)]]\n                   = [[6.1, 4.9], [3.1, 2.4]]\n    hidden_layer_1 = [[relu(1*6.1 -0.8*4.9 +0.2), relu(0.8*6.1 -1*4.9 -0.1)],\n                      [relu(1*3.1 -0.8*2.4 +0.2), relu(0.8*3.1 -1*2.4 -0.1)]]\n                   = [[2.38, 0], [1.38, 0]]\n    logits = [[-1*2.38 +0.3, 1*2.38 -0.3, 0.5*2.38],\n              [-1*1.38 +0.3, 1*1.38 -0.3, 0.5*1.38]]\n           = [[-2.08, 2.08, 1.19], 
[-1.08, 1.08, 0.69]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=3,\n          inputs=[[10.], [5.]],\n          expected_logits=[[-2.08, 2.08, 1.19], [-1.08, 1.08, .69]])\n\n  def test_multi_dim_input_one_dim_logits(self):\n    \"\"\"Tests multi-dimensional inputs and one-dimensional logits.\n\n    input_layer = [[10, 8]]\n    hidden_layer_0 = [[relu(0.6*10 -0.6*8 +0.1), relu(0.5*10 -0.5*8 -0.1)]]\n                   = [[1.3, 0.9]]\n    hidden_layer_1 = [[relu(1*1.3 -0.8*0.9 + 0.2), relu(0.8*1.3 -1*0.9 -0.2)]]\n                   = [[0.78, relu(-0.06)]] = [[0.78, 0]]\n    logits = [[-1*0.78 +1*0 +0.3]] = [[-0.48]]\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), base_global_step, self._model_dir)\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=1,\n          inputs=[[10., 8.]],\n          expected_logits=[[-0.48]])\n\n  def test_multi_dim_input_multi_dim_logits(self):\n    \"\"\"Tests multi-dimensional inputs and multi-dimensional logits.\n\n    input_layer = [[10, 8]]\n    hidden_layer_0 = [[relu(0.6*10 -0.6*8 +0.1), relu(0.5*10 -0.5*8 -0.1)]]\n                   = [[1.3, 0.9]]\n    hidden_layer_1 = [[relu(1*1.3 -0.8*0.9 + 0.2), relu(0.8*1.3 -1*0.9 -0.2)]]\n                   = [[0.78, relu(-0.06)]] = [[0.78, 0]]\n    logits = [[-1*0.78 + 0.3, 1*0.78 -0.3, 0.5*0.78]] = [[-0.48, 0.48, 0.39]]\n    \"\"\"\n    base_global_step = 100\n 
   create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      self._test_logits(\n          mode,\n          hidden_units=(2, 2),\n          logits_dimension=3,\n          inputs=[[10., 8.]],\n          expected_logits=[[-0.48, 0.48, 0.39]])\n\n  def test_multi_feature_column_multi_dim_logits(self):\n    \"\"\"Tests multiple feature columns and multi-dimensional logits.\n\n    All numbers are the same as test_multi_dim_input_multi_dim_logits. The only\n    difference is that the input consists of two 1D feature columns, instead of\n    one 2D feature column.\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n\n    hidden_units = (2, 2)\n    logits_dimension = 3\n    inputs = ([[10.]], [[8.]])\n    expected_logits = [[-0.48, 0.48, 0.39]]\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      with tf.Graph().as_default():\n        # Global step needed for MonitoredSession, which is in turn used to\n        # explicitly set variable weights through a checkpoint.\n        tf.compat.v1.train.create_global_step()\n        # Use a variable scope here with 'dnn', emulating the dnn model_fn, so\n        # the checkpoint naming is shared.\n        with tf.compat.v1.variable_scope('dnn'):\n          input_layer_partitioner = (\n              tf.compat.v1.min_max_variable_partitioner(\n                  max_partitions=0, min_slice_size=64 << 20))\n          logit_fn = self._dnn_logit_fn_builder(\n              units=logits_dimension,\n              hidden_units=hidden_units,\n              feature_columns=[\n                  
self._fc_impl.numeric_column('age'),\n                  self._fc_impl.numeric_column('height')\n              ],\n              activation_fn=tf.nn.relu,\n              dropout=None,\n              input_layer_partitioner=input_layer_partitioner,\n              batch_norm=False)\n          logits = logit_fn(\n              features={\n                  'age': tf.constant(inputs[0]),\n                  'height': tf.constant(inputs[1])\n              },\n              mode=mode)\n          with tf.compat.v1.train.MonitoredTrainingSession(\n              checkpoint_dir=self._model_dir) as sess:\n            self.assertAllClose(expected_logits, sess.run(logits))\n\n  def test_multi_feature_column_mix_multi_dim_logits(self):\n    \"\"\"Tests multiple feature columns and multi-dimensional logits.\n\n    All numbers are the same as test_multi_dim_input_multi_dim_logits. The only\n    difference is that the input consists of two 1D feature columns, instead of\n    one 2D feature column.\n    \"\"\"\n    base_global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n\n    hidden_units = (2, 2)\n    logits_dimension = 3\n    inputs = ([[10.]], [[8.]])\n    expected_logits = [[-0.48, 0.48, 0.39]]\n\n    for mode in [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]:\n      with tf.Graph().as_default():\n        # Global step needed for MonitoredSession, which is in turn used to\n        # explicitly set variable weights through a checkpoint.\n        tf.compat.v1.train.create_global_step()\n        # Use a variable scope here with 'dnn', emulating the dnn model_fn, so\n        # the checkpoint naming is shared.\n        with tf.compat.v1.variable_scope('dnn'):\n          input_layer_partitioner = (\n              tf.compat.v1.min_max_variable_partitioner(\n                  max_partitions=0, 
min_slice_size=64 << 20))\n          logit_fn = self._dnn_logit_fn_builder(\n              units=logits_dimension,\n              hidden_units=hidden_units,\n              feature_columns=[\n                  feature_column.numeric_column('age'),\n                  tf.feature_column.numeric_column('height')\n              ],\n              activation_fn=tf.nn.relu,\n              dropout=None,\n              input_layer_partitioner=input_layer_partitioner,\n              batch_norm=False)\n          logits = logit_fn(\n              features={\n                  'age': tf.constant(inputs[0]),\n                  'height': tf.constant(inputs[1])\n              },\n              mode=mode)\n          with tf.compat.v1.train.MonitoredTrainingSession(\n              checkpoint_dir=self._model_dir) as sess:\n            self.assertAllClose(expected_logits, sess.run(logits))\n\n\nclass BaseDNNWarmStartingTest(object):\n\n  def __init__(self,\n               _dnn_classifier_fn,\n               _dnn_regressor_fn,\n               fc_impl=feature_column):\n    self._dnn_classifier_fn = _dnn_classifier_fn\n    self._dnn_regressor_fn = _dnn_regressor_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    # Create a directory to save our old checkpoint and vocabularies to.\n    self._ckpt_and_vocab_dir = tempfile.mkdtemp()\n\n    # Make a dummy input_fn.\n    def _input_fn():\n      features = {\n          'city': [['Palo Alto'], ['Mountain View']],\n          'locality': [['Palo Alto'], ['Mountain View']],\n          'occupation': [['doctor'], ['consultant']]\n      }\n      return features, [0, 1]\n\n    self._input_fn = _input_fn\n\n  def tearDown(self):\n    # Clean up checkpoint / vocab dir.\n    tf.compat.v1.summary.FileWriterCache.clear()\n    shutil.rmtree(self._ckpt_and_vocab_dir)\n\n  def assertAllNotClose(self, t1, t2):\n    \"\"\"Helper assert for arrays.\"\"\"\n    sum_of_abs_diff = 0.0\n    for x, y in zip(t1, t2):\n      try:\n        for a, b in zip(x, y):\n  
        sum_of_abs_diff += abs(b - a)\n      except TypeError:\n        sum_of_abs_diff += abs(y - x)\n    self.assertGreater(sum_of_abs_diff, 0)\n\n  def test_classifier_basic_warm_starting(self):\n    \"\"\"Tests correctness of DNNClassifier default warm-start.\"\"\"\n    city = self._fc_impl.embedding_column(\n        self._fc_impl.categorical_column_with_vocabulary_list(\n            'city', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n\n    # Create a DNNClassifier and train to save a checkpoint.\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[city],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD')\n    dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNClassifier, warm-started from the first.  Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    warm_started_dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[city],\n        n_classes=4,\n        optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n            learning_rate=0.0),\n        warm_start_from=dnn_classifier.model_dir)\n\n    warm_started_dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_dnn_classifier.get_variable_names():\n      self.assertAllClose(\n          dnn_classifier.get_variable_value(variable_name),\n          warm_started_dnn_classifier.get_variable_value(variable_name))\n\n  def test_regressor_basic_warm_starting(self):\n    \"\"\"Tests correctness of DNNRegressor default warm-start.\"\"\"\n    city = self._fc_impl.embedding_column(\n        self._fc_impl.categorical_column_with_vocabulary_list(\n            'city', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n\n    # Create a DNNRegressor and train to save a 
checkpoint.\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=[256, 128],\n        feature_columns=[city],\n        model_dir=self._ckpt_and_vocab_dir,\n        optimizer='SGD')\n    dnn_regressor.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNRegressor, warm-started from the first.  Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    warm_started_dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=[256, 128],\n        feature_columns=[city],\n        optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n            learning_rate=0.0),\n        warm_start_from=dnn_regressor.model_dir)\n\n    warm_started_dnn_regressor.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_dnn_regressor.get_variable_names():\n      self.assertAllClose(\n          dnn_regressor.get_variable_value(variable_name),\n          warm_started_dnn_regressor.get_variable_value(variable_name))\n\n  def test_warm_starting_selective_variables(self):\n    \"\"\"Tests selecting variables to warm-start.\"\"\"\n    city = self._fc_impl.embedding_column(\n        self._fc_impl.categorical_column_with_vocabulary_list(\n            'city', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n\n    # Create a DNNClassifier and train to save a checkpoint.\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[city],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD')\n    dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNClassifier, warm-started from the first.  
Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    warm_started_dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[city],\n        n_classes=4,\n        optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n            learning_rate=0.0),\n        # The provided regular expression will only warm-start the city\n        # embedding, not the kernels and biases of the hidden weights.\n        warm_start_from=estimator.WarmStartSettings(\n            ckpt_to_initialize_from=dnn_classifier.model_dir,\n            vars_to_warm_start='.*(city).*'))\n\n    warm_started_dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_dnn_classifier.get_variable_names():\n      if 'city' in variable_name:\n        self.assertAllClose(\n            dnn_classifier.get_variable_value(variable_name),\n            warm_started_dnn_classifier.get_variable_value(variable_name))\n      elif 'bias' in variable_name:\n        # Hidden layer biases are zero-initialized.\n        bias_values = warm_started_dnn_classifier.get_variable_value(\n            variable_name)\n        self.assertAllClose(np.zeros_like(bias_values), bias_values)\n      elif 'kernel' in variable_name:\n        # We can't override the glorot uniform initializer used for the kernels\n        # in the dense layers, so just make sure we're not getting the same\n        # values from the old checkpoint.\n        self.assertAllNotClose(\n            dnn_classifier.get_variable_value(variable_name),\n            warm_started_dnn_classifier.get_variable_value(variable_name))\n\n  def test_warm_starting_with_vocab_remapping_and_partitioning(self):\n    \"\"\"Tests warm-starting with vocab remapping and partitioning.\"\"\"\n    vocab_list = ['doctor', 'lawyer', 'consultant']\n    vocab_file = os.path.join(self._ckpt_and_vocab_dir, 'occupation_vocab')\n    with 
open(vocab_file, 'w') as f:\n      f.write('\\n'.join(vocab_list))\n    occupation = self._fc_impl.embedding_column(\n        self._fc_impl.categorical_column_with_vocabulary_file(\n            'occupation',\n            vocabulary_file=vocab_file,\n            vocabulary_size=len(vocab_list)),\n        dimension=2)\n\n    # Create a DNNClassifier and train to save a checkpoint.\n    partitioner = tf.compat.v1.fixed_size_partitioner(num_shards=2)\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[occupation],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD',\n        input_layer_partitioner=partitioner)\n    dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNClassifier, warm-started from the first.  Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).  Use a new FeatureColumn with a\n    # different vocabulary for occupation.\n    new_vocab_list = ['doctor', 'consultant', 'engineer']\n    new_vocab_file = os.path.join(self._ckpt_and_vocab_dir,\n                                  'new_occupation_vocab')\n    with open(new_vocab_file, 'w') as f:\n      f.write('\\n'.join(new_vocab_list))\n    new_occupation = self._fc_impl.embedding_column(\n        self._fc_impl.categorical_column_with_vocabulary_file(\n            'occupation',\n            vocabulary_file=new_vocab_file,\n            vocabulary_size=len(new_vocab_list)),\n        dimension=2)\n    # We can create our VocabInfo object from the new and old occupation\n    # FeatureColumn's.\n    occupation_vocab_info = estimator.VocabInfo(\n        new_vocab=new_occupation.categorical_column.vocabulary_file,\n        new_vocab_size=new_occupation.categorical_column.vocabulary_size,\n        num_oov_buckets=new_occupation.categorical_column.num_oov_buckets,\n        
old_vocab=occupation.categorical_column.vocabulary_file,\n        old_vocab_size=occupation.categorical_column.vocabulary_size,\n        # Can't use constant_initializer with load_and_remap.  In practice,\n        # use a truncated normal initializer.\n        backup_initializer=tf.compat.v1.initializers.random_uniform(\n            minval=0.39, maxval=0.39))\n    warm_started_dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[occupation],\n        n_classes=4,\n        optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n            learning_rate=0.0),\n        warm_start_from=estimator.WarmStartSettings(\n            ckpt_to_initialize_from=dnn_classifier.model_dir,\n            var_name_to_vocab_info={\n                OCCUPATION_EMBEDDING_NAME: occupation_vocab_info\n            },\n            # Explicitly providing None here will only warm-start variables\n            # referenced in var_name_to_vocab_info (no hidden weights will be\n            # warmstarted).\n            vars_to_warm_start=None),\n        input_layer_partitioner=partitioner)\n\n    warm_started_dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n    # 'doctor' was ID-0 and still ID-0.\n    self.assertAllClose(\n        dnn_classifier.get_variable_value(OCCUPATION_EMBEDDING_NAME)[0, :],\n        warm_started_dnn_classifier.get_variable_value(\n            OCCUPATION_EMBEDDING_NAME)[0, :])\n    # 'consultant' was ID-2 and now ID-1.\n    self.assertAllClose(\n        dnn_classifier.get_variable_value(OCCUPATION_EMBEDDING_NAME)[2, :],\n        warm_started_dnn_classifier.get_variable_value(\n            OCCUPATION_EMBEDDING_NAME)[1, :])\n    # 'engineer' is a new entry and should be initialized with the\n    # backup_initializer in VocabInfo.\n    self.assertAllClose([0.39] * 2,\n                        warm_started_dnn_classifier.get_variable_value(\n                            OCCUPATION_EMBEDDING_NAME)[2, :])\n    for 
variable_name in warm_started_dnn_classifier.get_variable_names():\n      if 'bias' in variable_name:\n        # Hidden layer biases are zero-initialized.\n        bias_values = warm_started_dnn_classifier.get_variable_value(\n            variable_name)\n        self.assertAllClose(np.zeros_like(bias_values), bias_values)\n      elif 'kernel' in variable_name:\n        # We can't override the glorot uniform initializer used for the kernels\n        # in the dense layers, so just make sure we're not getting the same\n        # values from the old checkpoint.\n        self.assertAllNotClose(\n            dnn_classifier.get_variable_value(variable_name),\n            warm_started_dnn_classifier.get_variable_value(variable_name))\n\n  def test_warm_starting_with_naming_change(self):\n    \"\"\"Tests warm-starting with a Tensor name remapping.\"\"\"\n    locality = self._fc_impl.embedding_column(\n        self._fc_impl.categorical_column_with_vocabulary_list(\n            'locality', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n\n    # Create a DNNClassifier and train to save a checkpoint.\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[locality],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD')\n    dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second DNNClassifier, warm-started from the first.  
Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    city = self._fc_impl.embedding_column(\n        self._fc_impl.categorical_column_with_vocabulary_list(\n            'city', vocabulary_list=['Mountain View', 'Palo Alto']),\n        dimension=5)\n    warm_started_dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=[256, 128],\n        feature_columns=[city],\n        n_classes=4,\n        optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n            learning_rate=0.0),\n        # The 'city' variable correspond to the 'locality' variable in the\n        # previous model.\n        warm_start_from=estimator.WarmStartSettings(\n            ckpt_to_initialize_from=dnn_classifier.model_dir,\n            var_name_to_prev_var_name={\n                CITY_EMBEDDING_NAME:\n                    CITY_EMBEDDING_NAME.replace('city', 'locality')\n            }))\n\n    warm_started_dnn_classifier.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_dnn_classifier.get_variable_names():\n      if 'city' in variable_name:\n        self.assertAllClose(\n            dnn_classifier.get_variable_value(\n                CITY_EMBEDDING_NAME.replace('city', 'locality')),\n            warm_started_dnn_classifier.get_variable_value(CITY_EMBEDDING_NAME))\n      else:\n        self.assertAllClose(\n            dnn_classifier.get_variable_value(variable_name),\n            warm_started_dnn_classifier.get_variable_value(variable_name))\n\n\nclass BaseDNNClassifierEvaluateTest(object):\n\n  def __init__(self, dnn_classifier_fn, fc_impl=feature_column):\n    self._dnn_classifier_fn = dnn_classifier_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_one_dim(self):\n    
\"\"\"Asserts evaluation metrics for one-dimensional input and logits.\"\"\"\n    global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), global_step, self._model_dir)\n\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=(2, 2),\n        feature_columns=[self._fc_impl.numeric_column('age')],\n        model_dir=self._model_dir)\n\n    def _input_fn():\n      # batch_size = 2, one false label, and one true.\n      return {'age': [[10.], [10.]]}, [[1], [0]]\n\n    # Uses identical numbers as DNNModelTest.test_one_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [[-2.08], [-2.08]] =>\n    # logistic = 1/(1 + exp(-logits)) = [[0.11105597], [0.11105597]]\n    # loss = -1. * log(0.111) -1. * log(0.889) = 2.31544200\n    expected_loss = 2.31544200\n    self.assertAllClose(\n        {\n            metric_keys.MetricKeys.LOSS: expected_loss,\n            metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2.,\n            metric_keys.MetricKeys.ACCURACY: 0.5,\n            metric_keys.MetricKeys.PRECISION: 0.0,\n            metric_keys.MetricKeys.RECALL: 0.0,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 0.11105597,\n            metric_keys.MetricKeys.LABEL_MEAN: 0.5,\n            metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,\n            # There is no good way to calculate AUC for only two data points.\n            # But that is what the algorithm returns.\n            metric_keys.MetricKeys.AUC: 0.5,\n            metric_keys.MetricKeys.AUC_PR: 0.75,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: global_step\n        },\n        dnn_classifier.evaluate(input_fn=_input_fn, steps=1))\n\n  def test_multi_dim(self):\n    \"\"\"Asserts evaluation metrics for multi-dimensional input and logits.\"\"\"\n    global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, 
-1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), global_step, self._model_dir)\n    n_classes = 3\n\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=(2, 2),\n        feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    def _input_fn():\n      # batch_size = 2, one false label, and one true.\n      return {'age': [[10., 8.], [10., 8.]]}, [[1], [0]]\n\n    # Uses identical numbers as\n    # DNNModelFnTest.test_multi_dim_input_multi_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [[-0.48, 0.48, 0.39], [-0.48, 0.48, 0.39]]\n    # probabilities = exp(logits)/sum(exp(logits))\n    #               = [[0.16670536, 0.43538380, 0.39791084],\n    #                  [0.16670536, 0.43538380, 0.39791084]]\n    # loss = -log(0.43538380) - log(0.16670536)\n    expected_loss = 2.62305466\n    self.assertAllClose(\n        {\n            metric_keys.MetricKeys.LOSS: expected_loss,\n            metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,\n            metric_keys.MetricKeys.ACCURACY: 0.5,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: global_step\n        }, dnn_classifier.evaluate(input_fn=_input_fn, steps=1))\n\n  def test_float_labels(self):\n    \"\"\"Asserts evaluation metrics for float labels in binary classification.\"\"\"\n    global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), global_step, self._model_dir)\n\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=(2, 2),\n        feature_columns=[self._fc_impl.numeric_column('age')],\n        model_dir=self._model_dir)\n\n    def _input_fn():\n      # batch_size = 2, one false label, and one true.\n      return {'age': [[10.], [10.]]}, [[0.8], [0.4]]\n\n    # Uses identical numbers as DNNModelTest.test_one_dim_logits.\n    # 
See that test for calculation of logits.\n    # logits = [[-2.08], [-2.08]] =>\n    # logistic = 1/(1 + exp(-logits)) = [[0.11105597], [0.11105597]]\n    # loss = -0.8 * log(0.111) -0.2 * log(0.889)\n    #        -0.4 * log(0.111) -0.6 * log(0.889) = 2.7314420\n    metrics = dnn_classifier.evaluate(input_fn=_input_fn, steps=1)\n    self.assertAlmostEqual(2.7314420, metrics[metric_keys.MetricKeys.LOSS])\n\n  def test_multi_dim_weights(self):\n    \"\"\"Tests evaluation with weights.\"\"\"\n    # Uses same checkpoint with test_multi_dims\n    global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), global_step, self._model_dir)\n    n_classes = 3\n\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=(2, 2),\n        feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],\n        n_classes=n_classes,\n        weight_column='w',\n        model_dir=self._model_dir)\n\n    def _input_fn():\n      # batch_size = 2, one false label, and one true.\n      return {'age': [[10., 8.], [10., 8.]], 'w': [[10.], [100.]]}, [[1], [0]]\n\n    # Uses identical numbers as test_multi_dims\n    # See that test for calculation of logits.\n    # loss = -log(0.43538380)*10 - log(0.16670536)*100\n    expected_loss = 187.468007\n    metrics = dnn_classifier.evaluate(input_fn=_input_fn, steps=1)\n    self.assertAlmostEqual(\n        expected_loss, metrics[metric_keys.MetricKeys.LOSS], places=3)\n\n\nclass BaseDNNRegressorEvaluateTest(object):\n\n  def __init__(self, dnn_regressor_fn, fc_impl=feature_column):\n    self._dnn_regressor_fn = dnn_regressor_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_one_dim(self):\n    
\"\"\"Asserts evaluation metrics for one-dimensional input and logits.\"\"\"\n    # Create checkpoint: num_inputs=1, hidden_units=(2, 2), num_outputs=1.\n    global_step = 100\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), global_step, self._model_dir)\n\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=(2, 2),\n        feature_columns=[self._fc_impl.numeric_column('age')],\n        model_dir=self._model_dir)\n\n    def _input_fn():\n      return {'age': [[10.]]}, [[1.]]\n\n    # Uses identical numbers as DNNModelTest.test_one_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [[-2.08]] => predictions = [-2.08].\n    # loss = (1+2.08)^2 = 9.4864\n    expected_loss = 9.4864\n    self.assertAllClose(\n        {\n            metric_keys.MetricKeys.LOSS: expected_loss,\n            metric_keys.MetricKeys.LOSS_MEAN: expected_loss,\n            metric_keys.MetricKeys.PREDICTION_MEAN: -2.08,\n            metric_keys.MetricKeys.LABEL_MEAN: 1.0,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: global_step\n        }, dnn_regressor.evaluate(input_fn=_input_fn, steps=1))\n\n  def test_multi_dim(self):\n    \"\"\"Asserts evaluation metrics for multi-dimensional input and logits.\"\"\"\n    # Create checkpoint: num_inputs=2, hidden_units=(2, 2), num_outputs=3.\n    global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), global_step, self._model_dir)\n    label_dimension = 3\n\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=(2, 2),\n        feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],\n        label_dimension=label_dimension,\n        model_dir=self._model_dir)\n\n    def _input_fn():\n      return {'age': [[10., 8.]]}, [[1., -1., 0.5]]\n\n    # Uses identical 
numbers as\n    # DNNModelFnTest.test_multi_dim_input_multi_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [[-0.48, 0.48, 0.39]]\n    # loss = (1+0.48)^2 + (-1-0.48)^2 + (0.5-0.39)^2 = 4.3929\n    expected_loss = 4.3929\n    self.assertAllClose(\n        {\n            metric_keys.MetricKeys.LOSS: expected_loss,\n            metric_keys.MetricKeys.LOSS_MEAN: expected_loss / label_dimension,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 0.39 / 3.0,\n            metric_keys.MetricKeys.LABEL_MEAN: 0.5 / 3.0,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: global_step\n        }, dnn_regressor.evaluate(input_fn=_input_fn, steps=1))\n\n  def test_multi_dim_weights(self):\n    \"\"\"Asserts evaluation metrics for multi-dimensional input and logits.\"\"\"\n    # same checkpoint with test_multi_dim.\n    global_step = 100\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), global_step, self._model_dir)\n    label_dimension = 3\n\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=(2, 2),\n        feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],\n        label_dimension=label_dimension,\n        weight_column='w',\n        model_dir=self._model_dir)\n\n    def _input_fn():\n      return {'age': [[10., 8.]], 'w': [10.]}, [[1., -1., 0.5]]\n\n    # Uses identical numbers as test_multi_dim.\n    # See that test for calculation of logits.\n    # loss = 4.3929*10\n    expected_loss = 43.929\n    metrics = dnn_regressor.evaluate(input_fn=_input_fn, steps=1)\n    self.assertAlmostEqual(\n        expected_loss, metrics[metric_keys.MetricKeys.LOSS], places=3)\n\n\nclass BaseDNNClassifierPredictTest(object):\n\n  def __init__(self, dnn_classifier_fn, fc_impl=feature_column):\n    self._dnn_classifier_fn = dnn_classifier_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    
self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_one_dim(self, label_vocabulary, label_output_fn):\n    \"\"\"Asserts predictions for one-dimensional input and logits.\"\"\"\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ),\n                      global_step=0,\n                      model_dir=self._model_dir)\n\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=(2, 2),\n        label_vocabulary=label_vocabulary,\n        feature_columns=(self._fc_impl.numeric_column('x'),),\n        model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        x={'x': np.array([[10.]])}, batch_size=1, shuffle=False)\n    # Uses identical numbers as DNNModelTest.test_one_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [-2.08] =>\n    # logistic = exp(-2.08)/(1 + exp(-2.08)) = 0.11105597\n    # probabilities = [1-logistic, logistic] = [0.88894403, 0.11105597]\n    # class_ids = argmax(probabilities) = [0]\n    predictions = next(dnn_classifier.predict(input_fn=input_fn))\n    self.assertAllClose([-2.08],\n                        predictions[prediction_keys.PredictionKeys.LOGITS])\n    self.assertAllClose([0.11105597],\n                        predictions[prediction_keys.PredictionKeys.LOGISTIC])\n    self.assertAllClose(\n        [0.88894403, 0.11105597],\n        predictions[prediction_keys.PredictionKeys.PROBABILITIES])\n    self.assertAllClose([0],\n                        predictions[prediction_keys.PredictionKeys.CLASS_IDS])\n    self.assertAllEqual([label_output_fn(0)],\n                        predictions[prediction_keys.PredictionKeys.CLASSES])\n\n  def test_one_dim_without_label_vocabulary(self):\n    self._test_one_dim(\n        label_vocabulary=None, label_output_fn=lambda x: 
('%s' % x).encode())\n\n  def test_one_dim_with_label_vocabulary(self):\n    n_classes = 2\n    self._test_one_dim(\n        label_vocabulary=['class_vocab_{}'.format(i) for i in range(n_classes)],\n        label_output_fn=lambda x: ('class_vocab_%s' % x).encode())\n\n  def _test_multi_dim_with_3_classes(self, label_vocabulary, label_output_fn):\n    \"\"\"Asserts predictions for multi-dimensional input and logits.\"\"\"\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ),\n                      global_step=0,\n                      model_dir=self._model_dir)\n\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=(2, 2),\n        feature_columns=(self._fc_impl.numeric_column('x', shape=(2,)),),\n        label_vocabulary=label_vocabulary,\n        n_classes=3,\n        model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        # Inputs shape is (batch_size, num_inputs).\n        x={'x': np.array([[10., 8.]])},\n        batch_size=1,\n        shuffle=False)\n    # Uses identical numbers as\n    # DNNModelFnTest.test_multi_dim_input_multi_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [-0.48, 0.48, 0.39] =>\n    # probabilities[i] = exp(logits[i]) / sum_j exp(logits[j]) =>\n    # probabilities = [0.16670536, 0.43538380, 0.39791084]\n    # class_ids = argmax(probabilities) = [1]\n    predictions = next(dnn_classifier.predict(input_fn=input_fn))\n    self.assertItemsEqual([\n        prediction_keys.PredictionKeys.LOGITS,\n        prediction_keys.PredictionKeys.PROBABILITIES,\n        prediction_keys.PredictionKeys.CLASS_IDS,\n        prediction_keys.PredictionKeys.CLASSES,\n        prediction_keys.PredictionKeys.ALL_CLASS_IDS,\n        prediction_keys.PredictionKeys.ALL_CLASSES\n    ], six.iterkeys(predictions))\n    self.assertAllClose([-0.48, 0.48, 0.39],\n                       
 predictions[prediction_keys.PredictionKeys.LOGITS])\n    self.assertAllClose(\n        [0.16670536, 0.43538380, 0.39791084],\n        predictions[prediction_keys.PredictionKeys.PROBABILITIES])\n    self.assertAllEqual([1],\n                        predictions[prediction_keys.PredictionKeys.CLASS_IDS])\n    self.assertAllEqual([label_output_fn(1)],\n                        predictions[prediction_keys.PredictionKeys.CLASSES])\n\n  def test_multi_dim_with_3_classes_but_no_label_vocab(self):\n    self._test_multi_dim_with_3_classes(\n        label_vocabulary=None, label_output_fn=lambda x: ('%s' % x).encode())\n\n  def test_multi_dim_with_3_classes_and_label_vocab(self):\n    n_classes = 3\n    self._test_multi_dim_with_3_classes(\n        label_vocabulary=['class_vocab_{}'.format(i) for i in range(n_classes)],\n        label_output_fn=lambda x: ('class_vocab_%s' % x).encode())\n\n\nclass BaseDNNRegressorPredictTest(object):\n\n  def __init__(self, dnn_regressor_fn, fc_impl=feature_column):\n    self._dnn_regressor_fn = dnn_regressor_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_one_dim(self):\n    \"\"\"Asserts predictions for one-dimensional input and logits.\"\"\"\n    # Create checkpoint: num_inputs=1, hidden_units=(2, 2), num_outputs=1.\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ),\n                      global_step=0,\n                      model_dir=self._model_dir)\n\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=(2, 2),\n        feature_columns=(self._fc_impl.numeric_column('x'),),\n        model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        x={'x': np.array([[10.]])}, batch_size=1, shuffle=False)\n    # Uses 
identical numbers as DNNModelTest.test_one_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [[-2.08]] => predictions = [-2.08].\n    self.assertAllClose({\n        prediction_keys.PredictionKeys.PREDICTIONS: [-2.08],\n    }, next(dnn_regressor.predict(input_fn=input_fn)))\n\n  def test_multi_dim(self):\n    \"\"\"Asserts predictions for multi-dimensional input and logits.\"\"\"\n    # Create checkpoint: num_inputs=2, hidden_units=(2, 2), num_outputs=3.\n    create_checkpoint((\n        ([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), 100, self._model_dir)\n\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=(2, 2),\n        feature_columns=(self._fc_impl.numeric_column('x', shape=(2,)),),\n        label_dimension=3,\n        model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        # Inputs shape is (batch_size, num_inputs).\n        x={'x': np.array([[10., 8.]])},\n        batch_size=1,\n        shuffle=False)\n    # Uses identical numbers as\n    # DNNModelFnTest.test_multi_dim_input_multi_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [[-0.48, 0.48, 0.39]] => predictions = [-0.48, 0.48, 0.39]\n    self.assertAllClose(\n        {\n            prediction_keys.PredictionKeys.PREDICTIONS: [-0.48, 0.48, 0.39],\n        }, next(dnn_regressor.predict(input_fn=input_fn)))\n\n\nclass _SummaryHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Saves summaries every N steps.\"\"\"\n\n  def __init__(self):\n    self._summaries = []\n\n  def begin(self):\n    self._summary_op = tf.compat.v1.summary.merge_all()\n\n  def before_run(self, run_context):\n    return tf.compat.v1.train.SessionRunArgs({'summary': self._summary_op})\n\n  def after_run(self, run_context, run_values):\n    s = tf.compat.v1.summary.Summary()\n    s.ParseFromString(run_values.results['summary'])\n    
self._summaries.append(s)\n\n  def summaries(self):\n    return tuple(self._summaries)\n\n\ndef _assert_checkpoint(testcase, global_step, input_units, hidden_units,\n                       output_units, model_dir):\n  \"\"\"Asserts checkpoint contains expected variables with proper shapes.\n\n  Args:\n    testcase: A TestCase instance.\n    global_step: Expected global step value.\n    input_units: The dimension of input layer.\n    hidden_units: Iterable of integer sizes for the hidden layers.\n    output_units: The dimension of output layer (logits).\n    model_dir: The model directory.\n  \"\"\"\n  shapes = {name: shape for (name, shape) in tf.train.list_variables(model_dir)}\n\n  # Global step.\n  testcase.assertEqual([], shapes[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n  testcase.assertEqual(\n      global_step,\n      tf.train.load_variable(model_dir, tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n  # Hidden layer weights.\n  prev_layer_units = input_units\n  for i in range(len(hidden_units)):\n    layer_units = hidden_units[i]\n    testcase.assertAllEqual((prev_layer_units, layer_units),\n                            shapes[HIDDEN_WEIGHTS_NAME_PATTERN % i])\n    testcase.assertAllEqual((layer_units,),\n                            shapes[HIDDEN_BIASES_NAME_PATTERN % i])\n    prev_layer_units = layer_units\n\n  # Output layer weights.\n  testcase.assertAllEqual((prev_layer_units, output_units),\n                          shapes[LOGITS_WEIGHTS_NAME])\n  testcase.assertAllEqual((output_units,), shapes[LOGITS_BIASES_NAME])\n\n\ndef _assert_simple_summary(testcase, expected_values, actual_summary):\n  \"\"\"Assert summary the specified simple values.\n\n  Args:\n    testcase: A TestCase instance.\n    expected_values: Dict of expected tags and simple values.\n    actual_summary: `summary_pb2.Summary`.\n  \"\"\"\n  testcase.assertAllClose(\n      expected_values, {\n          v.tag: v.simple_value\n          for v in actual_summary.value\n          if (v.tag in 
expected_values)\n      })\n\n\nclass BaseDNNClassifierTrainTest(object):\n\n  def __init__(self, dnn_classifier_fn, fc_impl=feature_column):\n    self._dnn_classifier_fn = dnn_classifier_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_from_scratch_with_default_optimizer_binary(self):\n    hidden_units = (2, 2)\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        model_dir=self._model_dir)\n\n    # Train for a few steps, then validate final checkpoint.\n    num_steps = 5\n    dnn_classifier.train(\n        input_fn=lambda: ({\n            'age': [[10.]]\n        }, [[1]]), steps=num_steps)\n    _assert_checkpoint(\n        self,\n        num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=1,\n        model_dir=self._model_dir)\n\n  def test_from_scratch_with_default_optimizer_multi_class(self):\n    hidden_units = (2, 2)\n    n_classes = 3\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, then validate final checkpoint.\n    num_steps = 5\n    dnn_classifier.train(\n        input_fn=lambda: ({\n            'age': [[10.]]\n        }, [[2]]), steps=num_steps)\n    _assert_checkpoint(\n        self,\n        num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=n_classes,\n        model_dir=self._model_dir)\n\n  def test_from_scratch_validate_summary(self):\n    hidden_units = (2, 2)\n    opt = mock_optimizer(self, hidden_units=hidden_units)\n    dnn_classifier = self._dnn_classifier_fn(\n        
hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        optimizer=opt,\n        model_dir=self._model_dir)\n    self.assertEqual(0, opt.minimize.call_count)\n\n    # Train for a few steps, then validate optimizer, summaries, and\n    # checkpoint.\n    num_steps = 5\n    summary_hook = _SummaryHook()\n    dnn_classifier.train(\n        input_fn=lambda: ({\n            'age': [[10.]]\n        }, [[1]]),\n        steps=num_steps,\n        hooks=(summary_hook,))\n    self.assertEqual(1, opt.minimize.call_count)\n    _assert_checkpoint(\n        self,\n        num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=1,\n        model_dir=self._model_dir)\n    summaries = summary_hook.summaries()\n    self.assertEqual(num_steps, len(summaries))\n    for summary in summaries:\n      summary_keys = [v.tag for v in summary.value]\n      self.assertIn(metric_keys.MetricKeys.LOSS, summary_keys)\n      self.assertIn(metric_keys.MetricKeys.LOSS_MEAN, summary_keys)\n\n  def test_binary_classification(self):\n    base_global_step = 100\n    hidden_units = (2, 2)\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), base_global_step, self._model_dir)\n\n    # Uses identical numbers as DNNModelFnTest.test_one_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [-2.08] => probabilities = [0.889, 0.111]\n    # loss = -1. 
* log(0.111) = 2.19772100\n    expected_loss = 2.19772100\n    opt = mock_optimizer(\n        self, hidden_units=hidden_units, expected_loss=expected_loss)\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        optimizer=opt,\n        model_dir=self._model_dir)\n    self.assertEqual(0, opt.minimize.call_count)\n\n    # Train for a few steps, then validate optimizer, summaries, and\n    # checkpoint.\n    num_steps = 5\n    summary_hook = _SummaryHook()\n    dnn_classifier.train(\n        input_fn=lambda: ({\n            'age': [[10.]]\n        }, [[1]]),\n        steps=num_steps,\n        hooks=(summary_hook,))\n    self.assertEqual(1, opt.minimize.call_count)\n    summaries = summary_hook.summaries()\n    self.assertEqual(num_steps, len(summaries))\n    for summary in summaries:\n      _assert_simple_summary(\n          self, {\n              metric_keys.MetricKeys.LOSS_MEAN: expected_loss,\n              'dnn/dnn/hiddenlayer_0/fraction_of_zero_values': 0.,\n              'dnn/dnn/hiddenlayer_1/fraction_of_zero_values': .5,\n              'dnn/dnn/logits/fraction_of_zero_values': 0.,\n              metric_keys.MetricKeys.LOSS: expected_loss,\n          }, summary)\n    _assert_checkpoint(\n        self,\n        base_global_step + num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=1,\n        model_dir=self._model_dir)\n\n  def test_binary_classification_float_labels(self):\n    base_global_step = 100\n    hidden_units = (2, 2)\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), base_global_step, self._model_dir)\n\n    # Uses identical numbers as DNNModelFnTest.test_one_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [-2.08] => probabilities = [0.889, 0.111]\n    # loss = -0.8 * log(0.111) -0.2 * 
log(0.889) = 1.7817210\n    expected_loss = 1.7817210\n    opt = mock_optimizer(\n        self, hidden_units=hidden_units, expected_loss=expected_loss)\n    dnn_classifier = self._dnn_classifier_fn(\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        optimizer=opt,\n        model_dir=self._model_dir)\n    self.assertEqual(0, opt.minimize.call_count)\n\n    # Train for a few steps, then validate optimizer, summaries, and\n    # checkpoint.\n    num_steps = 5\n    dnn_classifier.train(\n        input_fn=lambda: ({\n            'age': [[10.]]\n        }, [[0.8]]), steps=num_steps)\n    self.assertEqual(1, opt.minimize.call_count)\n\n  def test_multi_class(self):\n    n_classes = 3\n    base_global_step = 100\n    hidden_units = (2, 2)\n    create_checkpoint((\n        ([[.6, .5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n\n    # Uses identical numbers as DNNModelFnTest.test_multi_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [-2.08, 2.08, 1.19] => probabilities = [0.0109, 0.7011, 0.2879]\n    # loss = -1. 
* log(0.7011) = 0.35505795\n    expected_loss = 0.35505795\n    opt = mock_optimizer(\n        self, hidden_units=hidden_units, expected_loss=expected_loss)\n    dnn_classifier = self._dnn_classifier_fn(\n        n_classes=n_classes,\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        optimizer=opt,\n        model_dir=self._model_dir)\n    self.assertEqual(0, opt.minimize.call_count)\n\n    # Train for a few steps, then validate optimizer, summaries, and\n    # checkpoint.\n    num_steps = 5\n    summary_hook = _SummaryHook()\n    dnn_classifier.train(\n        input_fn=lambda: ({\n            'age': [[10.]]\n        }, [[1]]),\n        steps=num_steps,\n        hooks=(summary_hook,))\n    self.assertEqual(1, opt.minimize.call_count)\n    summaries = summary_hook.summaries()\n    self.assertEqual(num_steps, len(summaries))\n    for summary in summaries:\n      _assert_simple_summary(\n          self, {\n              metric_keys.MetricKeys.LOSS_MEAN: expected_loss,\n              'dnn/dnn/hiddenlayer_0/fraction_of_zero_values': 0.,\n              'dnn/dnn/hiddenlayer_1/fraction_of_zero_values': .5,\n              'dnn/dnn/logits/fraction_of_zero_values': 0.,\n              metric_keys.MetricKeys.LOSS: expected_loss,\n          }, summary)\n    _assert_checkpoint(\n        self,\n        base_global_step + num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=n_classes,\n        model_dir=self._model_dir)\n\n\nclass BaseDNNRegressorTrainTest(object):\n\n  def __init__(self, dnn_regressor_fn, fc_impl=feature_column):\n    self._dnn_regressor_fn = dnn_regressor_fn\n    self._fc_impl = fc_impl\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_from_scratch_with_default_optimizer(self):\n    hidden_units = (2, 
2)\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        model_dir=self._model_dir)\n\n    # Train for a few steps, then validate final checkpoint.\n    num_steps = 5\n    dnn_regressor.train(\n        input_fn=lambda: ({\n            'age': ((1,),)\n        }, ((10,),)), steps=num_steps)\n    _assert_checkpoint(\n        self,\n        num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=1,\n        model_dir=self._model_dir)\n\n  def test_from_scratch(self):\n    hidden_units = (2, 2)\n    opt = mock_optimizer(self, hidden_units=hidden_units)\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        optimizer=opt,\n        model_dir=self._model_dir)\n    self.assertEqual(0, opt.minimize.call_count)\n\n    # Train for a few steps, then validate optimizer, summaries, and\n    # checkpoint.\n    num_steps = 5\n    summary_hook = _SummaryHook()\n    dnn_regressor.train(\n        input_fn=lambda: ({\n            'age': ((1,),)\n        }, ((5.,),)),\n        steps=num_steps,\n        hooks=(summary_hook,))\n    self.assertEqual(1, opt.minimize.call_count)\n    _assert_checkpoint(\n        self,\n        num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=1,\n        model_dir=self._model_dir)\n    summaries = summary_hook.summaries()\n    self.assertEqual(num_steps, len(summaries))\n    for summary in summaries:\n      summary_keys = [v.tag for v in summary.value]\n      self.assertIn(metric_keys.MetricKeys.LOSS, summary_keys)\n      self.assertIn(metric_keys.MetricKeys.LOSS_MEAN, summary_keys)\n\n  def test_one_dim(self):\n    \"\"\"Asserts train loss for one-dimensional input and logits.\"\"\"\n    base_global_step = 100\n    hidden_units = (2, 2)\n    create_checkpoint((\n        ([[.6, 
.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1.], [1.]], [.3]),\n    ), base_global_step, self._model_dir)\n\n    # Uses identical numbers as DNNModelFnTest.test_one_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [-2.08] => predictions = [-2.08]\n    # loss = (1 + 2.08)^2 = 9.4864\n    expected_loss = 9.4864\n    opt = mock_optimizer(\n        self, hidden_units=hidden_units, expected_loss=expected_loss)\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=hidden_units,\n        feature_columns=(self._fc_impl.numeric_column('age'),),\n        optimizer=opt,\n        model_dir=self._model_dir)\n    self.assertEqual(0, opt.minimize.call_count)\n\n    # Train for a few steps, then validate optimizer, summaries, and\n    # checkpoint.\n    num_steps = 5\n    summary_hook = _SummaryHook()\n    dnn_regressor.train(\n        input_fn=lambda: ({\n            'age': [[10.]]\n        }, [[1.]]),\n        steps=num_steps,\n        hooks=(summary_hook,))\n    self.assertEqual(1, opt.minimize.call_count)\n    summaries = summary_hook.summaries()\n    self.assertEqual(num_steps, len(summaries))\n    for summary in summaries:\n      _assert_simple_summary(\n          self, {\n              metric_keys.MetricKeys.LOSS_MEAN: expected_loss,\n              'dnn/dnn/hiddenlayer_0/fraction_of_zero_values': 0.,\n              'dnn/dnn/hiddenlayer_1/fraction_of_zero_values': 0.5,\n              'dnn/dnn/logits/fraction_of_zero_values': 0.,\n              metric_keys.MetricKeys.LOSS: expected_loss,\n          }, summary)\n    _assert_checkpoint(\n        self,\n        base_global_step + num_steps,\n        input_units=1,\n        hidden_units=hidden_units,\n        output_units=1,\n        model_dir=self._model_dir)\n\n  def test_multi_dim(self):\n    \"\"\"Asserts train loss for multi-dimensional input and logits.\"\"\"\n    base_global_step = 100\n    hidden_units = (2, 2)\n    create_checkpoint((\n        
([[.6, .5], [-.6, -.5]], [.1, -.1]),\n        ([[1., .8], [-.8, -1.]], [.2, -.2]),\n        ([[-1., 1., .5], [-1., 1., .5]], [.3, -.3, .0]),\n    ), base_global_step, self._model_dir)\n    input_dimension = 2\n    label_dimension = 3\n\n    # Uses identical numbers as\n    # DNNModelFnTest.test_multi_dim_input_multi_dim_logits.\n    # See that test for calculation of logits.\n    # logits = [[-0.48, 0.48, 0.39]]\n    # loss = (1+0.48)^2 + (-1-0.48)^2 + (0.5-0.39)^2 = 4.3929\n    expected_loss = 4.3929\n    opt = mock_optimizer(\n        self, hidden_units=hidden_units, expected_loss=expected_loss)\n    dnn_regressor = self._dnn_regressor_fn(\n        hidden_units=hidden_units,\n        feature_columns=[\n            self._fc_impl.numeric_column('age', shape=[input_dimension])\n        ],\n        label_dimension=label_dimension,\n        optimizer=opt,\n        model_dir=self._model_dir)\n    self.assertEqual(0, opt.minimize.call_count)\n\n    # Train for a few steps, then validate optimizer, summaries, and\n    # checkpoint.\n    num_steps = 5\n    summary_hook = _SummaryHook()\n    dnn_regressor.train(\n        input_fn=lambda: ({\n            'age': [[10., 8.]]\n        }, [[1., -1., 0.5]]),\n        steps=num_steps,\n        hooks=(summary_hook,))\n    self.assertEqual(1, opt.minimize.call_count)\n    summaries = summary_hook.summaries()\n    self.assertEqual(num_steps, len(summaries))\n    for summary in summaries:\n      _assert_simple_summary(\n          self, {\n              metric_keys.MetricKeys.LOSS_MEAN: expected_loss / label_dimension,\n              'dnn/dnn/hiddenlayer_0/fraction_of_zero_values': 0.,\n              'dnn/dnn/hiddenlayer_1/fraction_of_zero_values': 0.5,\n              'dnn/dnn/logits/fraction_of_zero_values': 0.,\n              metric_keys.MetricKeys.LOSS: expected_loss,\n          }, summary)\n    _assert_checkpoint(\n        self,\n        base_global_step + num_steps,\n        input_units=input_dimension,\n        
hidden_units=hidden_units,\n        output_units=label_dimension,\n        model_dir=self._model_dir)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/v1/linear_estimator_test_v1.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for LinearEstimator.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport shutil\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.canned import head as head_lib\nfrom tensorflow_estimator.python.estimator.canned import linear\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.canned.v1 import linear_testing_utils_v1\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\n\n\ndef _linear_estimator_fn(weight_column=None, label_dimension=1, **kwargs):\n  \"\"\"Returns a LinearEstimator that uses regression_head.\"\"\"\n  return linear.LinearEstimator(\n      head=head_lib._regression_head(\n          weight_column=weight_column,\n          label_dimension=label_dimension,\n          # Tests in core (from which this test inherits) test the sum loss.\n          loss_reduction=tf.compat.v1.losses.Reduction.SUM),\n      **kwargs)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass 
LinearEstimatorEvaluateTest(\n    linear_testing_utils_v1.BaseLinearRegressorEvaluationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorEvaluationTest.__init__(\n        self, _linear_estimator_fn)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearEstimatorPredictTest(\n    linear_testing_utils_v1.BaseLinearRegressorPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorPredictTest.__init__(\n        self, _linear_estimator_fn)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearEstimatorTrainTest(\n    linear_testing_utils_v1.BaseLinearRegressorTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorTrainingTest.__init__(\n        self, _linear_estimator_fn)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearEstimatorIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, label_dimension, batch_size):\n    feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    est = linear.LinearEstimator(\n        head=head_lib._regression_head(label_dimension=label_dimension),\n        feature_columns=feature_columns,\n        model_dir=self._model_dir)\n\n    # Train\n    num_steps = 10\n    est.train(train_input_fn, 
steps=num_steps)\n\n    # Evaluate\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(scores))\n\n    # Predict\n    predictions = np.array([\n        x[prediction_keys.PredictionKeys.PREDICTIONS]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, label_dimension), predictions.shape)\n\n    # Export\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/v1/linear_test_v1.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for linear.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.feature_column import feature_column\nfrom tensorflow.python.feature_column import feature_column_v2\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.canned import linear\nfrom tensorflow_estimator.python.estimator.canned.v1 import linear_testing_utils_v1\n\n\ndef _linear_regressor_fn(*args, **kwargs):\n  return linear.LinearRegressor(*args, **kwargs)\n\n\ndef _linear_classifier_fn(*args, **kwargs):\n  return linear.LinearClassifier(*args, **kwargs)\n\n\n# Tests for Linear Regressor.\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearRegressorPartitionerTest(\n    linear_testing_utils_v1.BaseLinearRegressorPartitionerTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorPartitionerTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests 
v1 only symbols')\nclass LinearRegressorPartitionerV2Test(\n    linear_testing_utils_v1.BaseLinearRegressorPartitionerTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorPartitionerTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearRegressorEvaluationTest(\n    linear_testing_utils_v1.BaseLinearRegressorEvaluationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorEvaluationTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearRegressorEvaluationV2Test(\n    linear_testing_utils_v1.BaseLinearRegressorEvaluationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorEvaluationTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearRegressorPredictTest(\n    linear_testing_utils_v1.BaseLinearRegressorPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorPredictTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearRegressorPredictV2Test(\n    linear_testing_utils_v1.BaseLinearRegressorPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    
tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorPredictTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearRegressorIntegrationTest(\n    linear_testing_utils_v1.BaseLinearRegressorIntegrationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorIntegrationTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearRegressorIntegrationV2Test(\n    linear_testing_utils_v1.BaseLinearRegressorIntegrationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorIntegrationTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearRegressorTrainingTest(\n    linear_testing_utils_v1.BaseLinearRegressorTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorTrainingTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearRegressorTrainingV2Test(\n    linear_testing_utils_v1.BaseLinearRegressorTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearRegressorTrainingTest.__init__(\n        self, _linear_regressor_fn, fc_lib=feature_column_v2)\n\n\n# Tests for Linear Classifier.\n@test_util.run_v1_only('Tests 
v1 only symbols')\nclass LinearClassifierTrainingTest(\n    linear_testing_utils_v1.BaseLinearClassifierTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierTrainingTest.__init__(\n        self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearClassifierTrainingV2Test(\n    linear_testing_utils_v1.BaseLinearClassifierTrainingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierTrainingTest.__init__(\n        self,\n        linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearClassifierEvaluationTest(\n    linear_testing_utils_v1.BaseLinearClassifierEvaluationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierEvaluationTest.__init__(\n        self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearClassifierEvaluationV2Test(\n    linear_testing_utils_v1.BaseLinearClassifierEvaluationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierEvaluationTest.__init__(\n        self,\n        linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearClassifierPredictTest(\n    linear_testing_utils_v1.BaseLinearClassifierPredictTest, 
tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierPredictTest.__init__(\n        self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearClassifierPredictV2Test(\n    linear_testing_utils_v1.BaseLinearClassifierPredictTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierPredictTest.__init__(\n        self,\n        linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearClassifierIntegrationTest(\n    linear_testing_utils_v1.BaseLinearClassifierIntegrationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierIntegrationTest.__init__(\n        self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearClassifierIntegrationV2Test(\n    linear_testing_utils_v1.BaseLinearClassifierIntegrationTest,\n    tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearClassifierIntegrationTest.__init__(\n        self,\n        linear_classifier_fn=_linear_classifier_fn,\n        fc_lib=feature_column_v2)\n\n\n# Tests for Linear logit_fn.\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearLogitFnTest(linear_testing_utils_v1.BaseLinearLogitFnTest,\n                        tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: 
disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearLogitFnTest.__init__(\n        self, fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearLogitFnV2Test(linear_testing_utils_v1.BaseLinearLogitFnTest,\n                          tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearLogitFnTest.__init__(\n        self, fc_lib=feature_column_v2)\n\n\n# Tests for warm-starting with Linear logit_fn.\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearWarmStartingTest(linear_testing_utils_v1.BaseLinearWarmStartingTest,\n                             tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearWarmStartingTest.__init__(\n        self,\n        _linear_classifier_fn,\n        _linear_regressor_fn,\n        fc_lib=feature_column)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass LinearWarmStartingV2Test(\n    linear_testing_utils_v1.BaseLinearWarmStartingTest, tf.test.TestCase):\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    linear_testing_utils_v1.BaseLinearWarmStartingTest.__init__(\n        self,\n        _linear_classifier_fn,\n        _linear_regressor_fn,\n        fc_lib=feature_column_v2)\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass ComputeFractionOfZeroTest(tf.test.TestCase):\n\n  def _assertSparsity(self, expected_sparsity, tensor):\n    sparsity = linear._compute_fraction_of_zero([tensor])\n    with self.test_session() as sess:\n      self.assertAllClose(expected_sparsity, sess.run(sparsity))\n\n  def test_small_float32(self):\n    self._assertSparsity(\n        0.75, 
ops.convert_to_tensor([0, 0, 0, 1], dtype=tf.dtypes.float32))\n    self._assertSparsity(\n        0.5, ops.convert_to_tensor([0, 1, 0, 1], dtype=tf.dtypes.float32))\n\n  def test_small_int32(self):\n    self._assertSparsity(\n        0.75, ops.convert_to_tensor([0, 0, 0, 1], dtype=tf.dtypes.int32))\n\n  def test_small_float64(self):\n    self._assertSparsity(\n        0.75, ops.convert_to_tensor([0, 0, 0, 1], dtype=tf.dtypes.float64))\n\n  def test_small_int64(self):\n    self._assertSparsity(\n        0.75, ops.convert_to_tensor([0, 0, 0, 1], dtype=tf.dtypes.int64))\n\n  def test_nested(self):\n    self._assertSparsity(\n        0.75, [ops.convert_to_tensor([0, 0]),\n               ops.convert_to_tensor([0, 1])])\n\n  def test_none(self):\n    with self.assertRaises(ValueError):\n      linear._compute_fraction_of_zero([])\n\n  def test_empty(self):\n    sparsity = linear._compute_fraction_of_zero([ops.convert_to_tensor([])])\n    with self.test_session() as sess:\n      sparsity_np = sess.run(sparsity)\n      self.assertTrue(\n          np.isnan(sparsity_np), 'Expected sparsity=nan, got %s' % sparsity_np)\n\n  def test_multiple_empty(self):\n    sparsity = linear._compute_fraction_of_zero([\n        ops.convert_to_tensor([]),\n        ops.convert_to_tensor([]),\n    ])\n    with self.test_session() as sess:\n      sparsity_np = sess.run(sparsity)\n      self.assertTrue(\n          np.isnan(sparsity_np), 'Expected sparsity=nan, got %s' % sparsity_np)\n\n  def test_some_empty(self):\n    with self.test_session():\n      self._assertSparsity(0.5, [\n          ops.convert_to_tensor([]),\n          ops.convert_to_tensor([0.]),\n          ops.convert_to_tensor([1.]),\n      ])\n\n  def test_mixed_types(self):\n    with self.test_session():\n      self._assertSparsity(0.6, [\n          ops.convert_to_tensor([0, 0, 1, 1, 1], dtype=tf.dtypes.float32),\n          ops.convert_to_tensor([0, 0, 0, 0, 1], dtype=tf.dtypes.int32),\n      ])\n\n  def 
test_2_27_zeros__using_512_MiB_of_ram(self):\n    self._assertSparsity(1., tf.zeros([int(2**27 * 1.01)],\n                                      dtype=tf.dtypes.int8))\n\n  def test_2_27_ones__using_512_MiB_of_ram(self):\n    self._assertSparsity(0., tf.ones([int(2**27 * 1.01)], dtype=tf.dtypes.int8))\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/canned/v1/linear_testing_utils_v1.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Utils for testing linear estimators.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport math\nimport os\nimport shutil\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.core.example import example_pb2\nfrom tensorflow.core.example import feature_pb2\nfrom tensorflow.python.feature_column import feature_column\nfrom tensorflow.python.feature_column import feature_column_v2\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.ops import variables as variables_lib\nfrom tensorflow_estimator.python.estimator import estimator\nfrom tensorflow_estimator.python.estimator import run_config\nfrom tensorflow_estimator.python.estimator.canned import linear\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.inputs import pandas_io\n\ntry:\n  # pylint: disable=g-import-not-at-top\n  import pandas as pd\n  HAS_PANDAS = True\nexcept IOError:\n  # Pandas writes a temporary file during import. 
If it fails, don't use pandas.\n  HAS_PANDAS = False\nexcept ImportError:\n  HAS_PANDAS = False\n\n# pylint rules which are disabled by default for test files.\n# pylint: disable=invalid-name,protected-access,missing-docstring\n\n# Names of variables created by model.\nAGE_WEIGHT_NAME = 'linear/linear_model/age/weights'\nHEIGHT_WEIGHT_NAME = 'linear/linear_model/height/weights'\nOCCUPATION_WEIGHT_NAME = 'linear/linear_model/occupation/weights'\nBIAS_NAME = 'linear/linear_model/bias_weights'\nLANGUAGE_WEIGHT_NAME = 'linear/linear_model/language/weights'\n\n# This is so that we can easily switch between feature_column and\n# feature_column_v2 for testing.\nfeature_column.numeric_column = feature_column._numeric_column\nfeature_column.categorical_column_with_hash_bucket = feature_column._categorical_column_with_hash_bucket  # pylint: disable=line-too-long\nfeature_column.categorical_column_with_vocabulary_list = feature_column._categorical_column_with_vocabulary_list  # pylint: disable=line-too-long\nfeature_column.categorical_column_with_vocabulary_file = feature_column._categorical_column_with_vocabulary_file  # pylint: disable=line-too-long\nfeature_column.embedding_column = feature_column._embedding_column\n\n\ndef assert_close(expected, actual, rtol=1e-04, name='assert_close'):\n  with ops.name_scope(name, 'assert_close', (expected, actual, rtol)) as scope:\n    expected = ops.convert_to_tensor(expected, name='expected')\n    actual = ops.convert_to_tensor(actual, name='actual')\n    rdiff = tf.math.abs(expected - actual, 'diff') / tf.math.abs(expected)\n    rtol = ops.convert_to_tensor(rtol, name='rtol')\n    return tf.compat.v1.debugging.assert_less(\n        rdiff,\n        rtol,\n        data=('Condition expected =~ actual did not hold element-wise:'\n              'expected = ', expected, 'actual = ', actual, 'rdiff = ', rdiff,\n              'rtol = ', rtol,),\n        name=scope)\n\n\ndef save_variables_to_ckpt(model_dir):\n  init_all_op = 
[tf.compat.v1.initializers.global_variables()]\n  with tf.compat.v1.Session() as sess:\n    sess.run(init_all_op)\n    tf.compat.v1.train.Saver().save(sess, os.path.join(model_dir, 'model.ckpt'))\n\n\ndef queue_parsed_features(feature_map):\n  tensors_to_enqueue = []\n  keys = []\n  for key, tensor in six.iteritems(feature_map):\n    keys.append(key)\n    tensors_to_enqueue.append(tensor)\n  queue_dtypes = [x.dtype for x in tensors_to_enqueue]\n  input_queue = tf.queue.FIFOQueue(capacity=100, dtypes=queue_dtypes)\n  tf.compat.v1.train.queue_runner.add_queue_runner(\n      tf.compat.v1.train.queue_runner.QueueRunner(\n          input_queue, [input_queue.enqueue(tensors_to_enqueue)]))\n  dequeued_tensors = input_queue.dequeue()\n  return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}\n\n\ndef sorted_key_dict(unsorted_dict):\n  return {k: unsorted_dict[k] for k in sorted(unsorted_dict)}\n\n\ndef sigmoid(x):\n  return 1 / (1 + np.exp(-1.0 * x))\n\n\nclass CheckPartitionerVarHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"A `SessionRunHook` to check a partitioned variable.\"\"\"\n\n  def __init__(self, test_case, var_name, var_dim, partitions):\n    self._test_case = test_case\n    self._var_name = var_name\n    self._var_dim = var_dim\n    self._partitions = partitions\n\n  def begin(self):\n    with tf.compat.v1.variable_scope(\n        tf.compat.v1.get_variable_scope()) as scope:\n      scope.reuse_variables()\n      partitioned_weight = tf.compat.v1.get_variable(\n          self._var_name, shape=(self._var_dim, 1))\n      self._test_case.assertTrue(\n          isinstance(partitioned_weight, variables_lib.PartitionedVariable))\n      for part in partitioned_weight:\n        self._test_case.assertEqual(self._var_dim // self._partitions,\n                                    part.get_shape()[0])\n\n\nclass BaseLinearRegressorPartitionerTest(object):\n\n  def __init__(self, linear_regressor_fn, fc_lib=feature_column):\n    
self._linear_regressor_fn = linear_regressor_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def testPartitioner(self):\n    x_dim = 64\n    partitions = 4\n\n    def _partitioner(shape, dtype):\n      del dtype  # unused; required by Fn signature.\n      # Only partition the embedding tensor.\n      return [partitions, 1] if shape[0] == x_dim else [1]\n\n    regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.categorical_column_with_hash_bucket(\n            'language', hash_bucket_size=x_dim),),\n        partitioner=_partitioner,\n        model_dir=self._model_dir)\n\n    def _input_fn():\n      return {\n          'language':\n              tf.sparse.SparseTensor(\n                  values=['english', 'spanish'],\n                  indices=[[0, 0], [0, 1]],\n                  dense_shape=[1, 2])\n      }, [[10.]]\n\n    hook = CheckPartitionerVarHook(self, LANGUAGE_WEIGHT_NAME, x_dim,\n                                   partitions)\n    regressor.train(input_fn=_input_fn, steps=1, hooks=[hook])\n\n  def testDefaultPartitionerWithMultiplePsReplicas(self):\n    partitions = 2\n    # This results in weights larger than the default partition size of 64M,\n    # so partitioned weights are created (each weight uses 4 bytes).\n    x_dim = 32 << 20\n\n    class FakeRunConfig(run_config.RunConfig):\n\n      @property\n      def num_ps_replicas(self):\n        return partitions\n\n    # Mock the device setter as ps is not available on test machines.\n    with tf.compat.v1.test.mock.patch.object(\n        estimator,\n        '_get_replica_device_setter',\n        return_value=lambda _: '/cpu:0'):\n      linear_regressor = self._linear_regressor_fn(\n          feature_columns=(self._fc_lib.categorical_column_with_hash_bucket(\n              'language', 
hash_bucket_size=x_dim),),\n          config=FakeRunConfig(),\n          model_dir=self._model_dir)\n\n      def _input_fn():\n        return {\n            'language':\n                tf.sparse.SparseTensor(\n                    values=['english', 'spanish'],\n                    indices=[[0, 0], [0, 1]],\n                    dense_shape=[1, 2])\n        }, [[10.]]\n\n      hook = CheckPartitionerVarHook(self, LANGUAGE_WEIGHT_NAME, x_dim,\n                                     partitions)\n      linear_regressor.train(input_fn=_input_fn, steps=1, hooks=[hook])\n\n\n# TODO(b/36813849): Add tests with dynamic shape inputs using placeholders.\nclass BaseLinearRegressorEvaluationTest(object):\n\n  def __init__(self, linear_regressor_fn, fc_lib=feature_column):\n    self._linear_regressor_fn = linear_regressor_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_evaluation_for_simple_data(self):\n    with tf.Graph().as_default():\n      tf.Variable([[11.0]], name=AGE_WEIGHT_NAME)\n      tf.Variable([2.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        model_dir=self._model_dir)\n    eval_metrics = linear_regressor.evaluate(\n        input_fn=lambda: ({\n            'age': ((1,),)\n        }, ((10.,),)), steps=1)\n\n    # Logit is (1. * 11.0 + 2.0) = 13, while label is 10. 
Loss is 3**2 = 9.\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 9.,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def test_evaluation_batch(self):\n    \"\"\"Tests evaluation for batch_size==2.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([[11.0]], name=AGE_WEIGHT_NAME)\n      tf.Variable([2.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        model_dir=self._model_dir)\n    eval_metrics = linear_regressor.evaluate(\n        input_fn=lambda: ({\n            'age': ((1,), (1,))\n        }, ((10.,), (10.,))), steps=1)\n\n    # Logit is (1. 
* 11.0 + 2.0) = 13, while label is 10.\n    # Loss per example is 3**2 = 9.\n    # Training loss is the sum over batch = 9 + 9 = 18\n    # Average loss is the average over batch = 9\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 18.,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def test_evaluation_weights(self):\n    \"\"\"Tests evaluation with weights.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([[11.0]], name=AGE_WEIGHT_NAME)\n      tf.Variable([2.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    def _input_fn():\n      features = {'age': ((1,), (1,)), 'weights': ((1.,), (2.,))}\n      labels = ((10.,), (10.,))\n      return features, labels\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        weight_column='weights',\n        model_dir=self._model_dir)\n    eval_metrics = linear_regressor.evaluate(input_fn=_input_fn, steps=1)\n\n    # Logit is (1. 
* 11.0 + 2.0) = 13, while label is 10.\n    # Loss per example is 3**2 = 9.\n    # Training loss is the weighted sum over batch = 9 + 2*9 = 27\n    # average loss is the weighted average = 9 + 2*9 / (1 + 2) = 9\n    self.assertDictEqual(\n        {\n            metric_keys.MetricKeys.LOSS: 27.,\n            metric_keys.MetricKeys.LOSS_MEAN: 9.,\n            metric_keys.MetricKeys.PREDICTION_MEAN: 13.,\n            metric_keys.MetricKeys.LABEL_MEAN: 10.,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP: 100\n        }, eval_metrics)\n\n  def test_evaluation_for_multi_dimensions(self):\n    x_dim = 3\n    label_dim = 2\n    with tf.Graph().as_default():\n      tf.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name=AGE_WEIGHT_NAME)\n      tf.Variable([7.0, 8.0], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age', shape=(x_dim,)),),\n        label_dimension=label_dim,\n        model_dir=self._model_dir)\n    input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': np.array([[2., 4., 5.]]),\n        },\n        y=np.array([[46., 58.]]),\n        batch_size=1,\n        num_epochs=None,\n        shuffle=False)\n    eval_metrics = linear_regressor.evaluate(input_fn=input_fn, steps=1)\n\n    self.assertItemsEqual(\n        (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,\n         metric_keys.MetricKeys.PREDICTION_MEAN,\n         metric_keys.MetricKeys.LABEL_MEAN, tf.compat.v1.GraphKeys.GLOBAL_STEP),\n        eval_metrics.keys())\n\n    # Logit is\n    #   [2., 4., 5.] 
* [1.0, 2.0] + [7.0, 8.0] = [39, 50] + [7.0, 8.0]\n    #                  [3.0, 4.0]\n    #                  [5.0, 6.0]\n    # which is [46, 58]\n    self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])\n\n  def test_evaluation_for_multiple_feature_columns(self):\n    with tf.Graph().as_default():\n      tf.Variable([[10.0]], name=AGE_WEIGHT_NAME)\n      tf.Variable([[2.0]], name=HEIGHT_WEIGHT_NAME)\n      tf.Variable([5.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    batch_size = 2\n    feature_columns = [\n        self._fc_lib.numeric_column('age'),\n        self._fc_lib.numeric_column('height')\n    ]\n    input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': np.array([20, 40]),\n            'height': np.array([4, 8])\n        },\n        y=np.array([[213.], [421.]]),\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=False)\n\n    est = self._linear_regressor_fn(\n        feature_columns=feature_columns, model_dir=self._model_dir)\n\n    eval_metrics = est.evaluate(input_fn=input_fn, steps=1)\n    self.assertItemsEqual(\n        (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,\n         metric_keys.MetricKeys.PREDICTION_MEAN,\n         metric_keys.MetricKeys.LABEL_MEAN, tf.compat.v1.GraphKeys.GLOBAL_STEP),\n        eval_metrics.keys())\n\n    # Logit is [(20. * 10.0 + 4 * 2.0 + 5.0), (40. * 10.0 + 8 * 2.0 + 5.0)] =\n    # [213.0, 421.0], while label is [213., 421.]. 
Loss = 0.\n    self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])\n\n  def test_evaluation_for_multiple_feature_columns_mix(self):\n    with tf.Graph().as_default():\n      tf.Variable([[10.0]], name=AGE_WEIGHT_NAME)\n      tf.Variable([[2.0]], name=HEIGHT_WEIGHT_NAME)\n      tf.Variable([5.0], name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    batch_size = 2\n    feature_columns = [\n        feature_column.numeric_column('age'),\n        tf.feature_column.numeric_column('height')\n    ]\n\n    def _input_fn():\n      features_ds = tf.compat.v1.data.Dataset.from_tensor_slices({\n          'age': np.array([20, 40]),\n          'height': np.array([4, 8])\n      })\n      labels_ds = tf.compat.v1.data.Dataset.from_tensor_slices(\n          np.array([[213.], [421.]]))\n      return (tf.compat.v1.data.Dataset.zip(\n          (features_ds, labels_ds)).batch(batch_size).repeat(None))\n\n    est = self._linear_regressor_fn(\n        feature_columns=feature_columns, model_dir=self._model_dir)\n\n    eval_metrics = est.evaluate(input_fn=_input_fn, steps=1)\n    self.assertItemsEqual(\n        (metric_keys.MetricKeys.LOSS, metric_keys.MetricKeys.LOSS_MEAN,\n         metric_keys.MetricKeys.PREDICTION_MEAN,\n         metric_keys.MetricKeys.LABEL_MEAN, tf.compat.v1.GraphKeys.GLOBAL_STEP),\n        eval_metrics.keys())\n\n    # Logit is [(20. * 10.0 + 4 * 2.0 + 5.0), (40. * 10.0 + 8 * 2.0 + 5.0)] =\n    # [213.0, 421.0], while label is [213., 421.]. 
Loss = 0.\n    self.assertAlmostEqual(0, eval_metrics[metric_keys.MetricKeys.LOSS])\n\n\nclass BaseLinearRegressorPredictTest(object):\n\n  def __init__(self, linear_regressor_fn, fc_lib=feature_column):\n    self._linear_regressor_fn = linear_regressor_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def test_1d(self):\n    \"\"\"Tests predict when all variables are one-dimensional.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([[10.]], name='linear/linear_model/x/weights')\n      tf.Variable([.2], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('x'),),\n        model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': np.array([[2.]])},\n        y=None,\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    predictions = linear_regressor.predict(input_fn=predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # x * weight + bias = 2. * 10. 
+ .2 = 20.2\n    self.assertAllClose([[20.2]], predicted_scores)\n\n  def testMultiDim(self):\n    \"\"\"Tests predict when all variables are multi-dimenstional.\"\"\"\n    batch_size = 2\n    label_dimension = 3\n    x_dim = 4\n    feature_columns = (self._fc_lib.numeric_column('x', shape=(x_dim,)),)\n    with tf.Graph().as_default():\n      tf.Variable(  # shape=[x_dim, label_dimension]\n          [[1., 2., 3.], [2., 3., 4.], [3., 4., 5.], [4., 5., 6.]],\n          name='linear/linear_model/x/weights')\n      tf.Variable(  # shape=[label_dimension]\n          [.2, .4, .6], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=feature_columns,\n        label_dimension=label_dimension,\n        model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        # x shape=[batch_size, x_dim]\n        x={'x': np.array([[1., 2., 3., 4.], [5., 6., 7., 8.]])},\n        y=None,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n    predictions = linear_regressor.predict(input_fn=predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # score = x * weight + bias, shape=[batch_size, label_dimension]\n    self.assertAllClose([[30.2, 40.4, 50.6], [70.2, 96.4, 122.6]],\n                        predicted_scores)\n\n  def testTwoFeatureColumns(self):\n    \"\"\"Tests predict with two feature columns.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([[10.]], name='linear/linear_model/x0/weights')\n      tf.Variable([[20.]], name='linear/linear_model/x1/weights')\n      tf.Variable([.2], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('x0'),\n       
                  self._fc_lib.numeric_column('x1')),\n        model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={\n            'x0': np.array([[2.]]),\n            'x1': np.array([[3.]])\n        },\n        y=None,\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    predictions = linear_regressor.predict(input_fn=predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # x0 * weight0 + x1 * weight1 + bias = 2. * 10. + 3. * 20 + .2 = 80.2\n    self.assertAllClose([[80.2]], predicted_scores)\n\n  def testTwoFeatureColumnsMix(self):\n    \"\"\"Tests predict with two feature columns.\"\"\"\n    with tf.Graph().as_default():\n      tf.Variable([[10.]], name='linear/linear_model/x0/weights')\n      tf.Variable([[20.]], name='linear/linear_model/x1/weights')\n      tf.Variable([.2], name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(feature_column.numeric_column('x0'),\n                         tf.feature_column.numeric_column('x1')),\n        model_dir=self._model_dir)\n\n    def _predict_input_fn():\n      return tf.compat.v1.data.Dataset.from_tensor_slices({\n          'x0': np.array([[2.]]),\n          'x1': np.array([[3.]])\n      }).batch(1)\n\n    predictions = linear_regressor.predict(input_fn=_predict_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    # x0 * weight0 + x1 * weight1 + bias = 2. * 10. + 3. 
* 20 + .2 = 80.2\n    self.assertAllClose([[80.2]], predicted_scores)\n\n  def testSparseCombiner(self):\n    w_a = 2.0\n    w_b = 3.0\n    w_c = 5.0\n    bias = 5.0\n    with tf.Graph().as_default():\n      tf.Variable([[w_a], [w_b], [w_c]], name=LANGUAGE_WEIGHT_NAME)\n      tf.Variable([bias], name=BIAS_NAME)\n      tf.Variable(\n          1, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    def _input_fn():\n      return tf.compat.v1.data.Dataset.from_tensors({\n          'language':\n              tf.sparse.SparseTensor(\n                  values=['a', 'c', 'b', 'c'],\n                  indices=[[0, 0], [0, 1], [1, 0], [1, 1]],\n                  dense_shape=[2, 2]),\n      })\n\n    feature_columns = (self._fc_lib.categorical_column_with_vocabulary_list(\n        'language', vocabulary_list=['a', 'b', 'c']),)\n\n    # Check prediction for each sparse_combiner.\n    # With sparse_combiner = 'sum', we have\n    # logits_1 = w_a + w_c + bias\n    #          = 2.0 + 5.0 + 5.0 = 12.0\n    # logits_2 = w_b + w_c + bias\n    #          = 3.0 + 5.0 + 5.0 = 13.0\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=feature_columns, model_dir=self._model_dir)\n    predictions = linear_regressor.predict(input_fn=_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    self.assertAllClose([[12.0], [13.0]], predicted_scores)\n\n    # With sparse_combiner = 'mean', we have\n    # logits_1 = 1/2 * (w_a + w_c) + bias\n    #          = 1/2 * (2.0 + 5.0) + 5.0 = 8.5\n    # logits_2 = 1/2 * (w_b + w_c) + bias\n    #          = 1/2 * (3.0 + 5.0) + 5.0 = 9.0\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=feature_columns,\n        model_dir=self._model_dir,\n        sparse_combiner='mean')\n    predictions = linear_regressor.predict(input_fn=_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    
self.assertAllClose([[8.5], [9.0]], predicted_scores)\n\n    # With sparse_combiner = 'sqrtn', we have\n    # logits_1 = sqrt(2)/2 * (w_a + w_c) + bias\n    #          = sqrt(2)/2 * (2.0 + 5.0) + 5.0 = 9.94974\n    # logits_2 = sqrt(2)/2 * (w_b + w_c) + bias\n    #          = sqrt(2)/2 * (3.0 + 5.0) + 5.0 = 10.65685\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=feature_columns,\n        model_dir=self._model_dir,\n        sparse_combiner='sqrtn')\n    predictions = linear_regressor.predict(input_fn=_input_fn)\n    predicted_scores = list([x['predictions'] for x in predictions])\n    self.assertAllClose([[9.94974], [10.65685]], predicted_scores)\n\n\nclass BaseLinearRegressorIntegrationTest(object):\n\n  def __init__(self, linear_regressor_fn, fc_lib=feature_column):\n    self._linear_regressor_fn = linear_regressor_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,\n                          input_dimension, label_dimension, prediction_length):\n    feature_columns = [\n        self._fc_lib.numeric_column('x', shape=(input_dimension,))\n    ]\n    est = self._linear_regressor_fn(\n        feature_columns=feature_columns,\n        label_dimension=label_dimension,\n        model_dir=self._model_dir)\n\n    # TRAIN\n    # learn y = x\n    est.train(train_input_fn, steps=200)\n\n    # EVALUTE\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(200, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores))\n\n    # PREDICT\n    predictions = np.array(\n        [x['predictions'] for x in est.predict(predict_input_fn)])\n    self.assertAllEqual((prediction_length, label_dimension), predictions.shape)\n\n    # 
EXPORT\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def test_numpy_input_fn(self):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    label_dimension = 2\n    input_dimension = label_dimension\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=None,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        label_dimension=label_dimension,\n        prediction_length=prediction_length)\n\n  def test_pandas_input_fn(self):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n\n    # Pandas DataFrame natually supports 1 dim data only.\n    label_dimension = 1\n    input_dimension = label_dimension\n    batch_size = 10\n    data = np.array([1., 2., 3., 4.], dtype=np.float32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(data)\n    prediction_length = 4\n\n    train_input_fn = 
pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        label_dimension=label_dimension,\n        prediction_length=prediction_length)\n\n  def test_input_fn_from_parse_example(self):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    label_dimension = 2\n    input_dimension = label_dimension\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n\n    serialized_examples = []\n    for datum in data:\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=datum)),\n                  'y':\n                      feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(\n                              value=datum[:label_dimension])),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([input_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([label_dimension], tf.dtypes.float32),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n     
 return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=input_dimension,\n        label_dimension=label_dimension,\n        prediction_length=prediction_length)\n\n\nclass BaseLinearRegressorTrainingTest(object):\n\n  def __init__(self, linear_regressor_fn, fc_lib=feature_column):\n    self._linear_regressor_fn = linear_regressor_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n  def _mock_optimizer(self, expected_loss=None):\n    expected_var_names = [\n        '%s/part_0:0' % AGE_WEIGHT_NAME,\n        '%s/part_0:0' % BIAS_NAME\n    ]\n\n    def _minimize(loss, global_step=None, var_list=None):\n      trainable_vars = var_list or tf.compat.v1.get_collection(\n          tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n      self.assertItemsEqual(expected_var_names,\n                            [var.name for var in trainable_vars])\n\n      # Verify loss. 
We can't check the value directly, so we add an assert op.\n      self.assertEquals(0, loss.shape.ndims)\n      if expected_loss is None:\n        if global_step is not None:\n          return tf.compat.v1.assign_add(global_step, 1).op\n        return tf.no_op()\n      assert_loss = assert_close(\n          tf.cast(expected_loss, name='expected', dtype=tf.dtypes.float32),\n          loss,\n          name='assert_loss')\n      with tf.control_dependencies((assert_loss,)):\n        if global_step is not None:\n          return tf.compat.v1.assign_add(global_step, 1).op\n        return tf.no_op()\n\n    mock_optimizer = tf.compat.v1.test.mock.NonCallableMock(\n        spec=tf.compat.v1.train.Optimizer,\n        wraps=tf.compat.v1.train.Optimizer(\n            use_locking=False, name='my_optimizer'))\n    mock_optimizer.minimize = tf.compat.v1.test.mock.MagicMock(wraps=_minimize)\n\n    # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks.\n    # So, return mock_optimizer itself for deepcopy.\n    mock_optimizer.__deepcopy__ = lambda _: mock_optimizer\n    return mock_optimizer\n\n  def _assert_checkpoint(self,\n                         expected_global_step,\n                         expected_age_weight=None,\n                         expected_bias=None):\n    shapes = {\n        name: shape\n        for (name, shape) in tf.train.list_variables(self._model_dir)\n    }\n\n    self.assertEqual([], shapes[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertEqual(\n        expected_global_step,\n        tf.train.load_variable(self._model_dir,\n                               tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n    self.assertEqual([1, 1], shapes[AGE_WEIGHT_NAME])\n    if expected_age_weight is not None:\n      self.assertEqual(expected_age_weight,\n                       tf.train.load_variable(self._model_dir, AGE_WEIGHT_NAME))\n\n    self.assertEqual([1], shapes[BIAS_NAME])\n    if expected_bias is not None:\n      
self.assertEqual(expected_bias,\n                       tf.train.load_variable(self._model_dir, BIAS_NAME))\n\n  def testFromScratchWithDefaultOptimizer(self):\n    # Create LinearRegressor.\n    label = 5.\n    age = 17\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        model_dir=self._model_dir)\n\n    # Train for a few steps, and validate final checkpoint.\n    num_steps = 10\n    linear_regressor.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self._assert_checkpoint(num_steps)\n\n  def testTrainWithOneDimLabel(self):\n    label_dimension = 1\n    batch_size = 20\n    feature_columns = [self._fc_lib.numeric_column('age', shape=(1,))]\n    est = self._linear_regressor_fn(\n        feature_columns=feature_columns,\n        label_dimension=label_dimension,\n        model_dir=self._model_dir)\n    data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32)\n    self.assertEqual((batch_size,), data_rank_1.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'age': data_rank_1},\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(200)\n\n  def testTrainWithOneDimWeight(self):\n    label_dimension = 1\n    batch_size = 20\n    feature_columns = [self._fc_lib.numeric_column('age', shape=(1,))]\n    est = self._linear_regressor_fn(\n        feature_columns=feature_columns,\n        label_dimension=label_dimension,\n        weight_column='w',\n        model_dir=self._model_dir)\n\n    data_rank_1 = np.linspace(0., 2., batch_size, dtype=np.float32)\n    self.assertEqual((batch_size,), data_rank_1.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': data_rank_1,\n            'w': data_rank_1\n        },\n        y=data_rank_1,\n        batch_size=batch_size,\n  
      num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(200)\n\n  def testFromScratch(self):\n    # Create LinearRegressor.\n    label = 5.\n    age = 17\n    # loss = (logits - label)^2 = (0 - 5.)^2 = 25.\n    mock_optimizer = self._mock_optimizer(expected_loss=25.)\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        model_dir=self._model_dir,\n        optimizer=mock_optimizer)\n    self.assertEqual(0, mock_optimizer.minimize.call_count)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    linear_regressor.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self.assertEqual(1, mock_optimizer.minimize.call_count)\n    self._assert_checkpoint(\n        expected_global_step=num_steps,\n        expected_age_weight=0.,\n        expected_bias=0.)\n\n  def testFromCheckpoint(self):\n    # Create initial checkpoint.\n    age_weight = 10.0\n    bias = 5.0\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable([[age_weight]], name=AGE_WEIGHT_NAME)\n      tf.Variable([bias], name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # logits = age * age_weight + bias = 17 * 10. + 5. 
= 175\n    # loss = (logits - label)^2 = (175 - 5)^2 = 28900\n    mock_optimizer = self._mock_optimizer(expected_loss=28900.)\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        model_dir=self._model_dir,\n        optimizer=mock_optimizer)\n    self.assertEqual(0, mock_optimizer.minimize.call_count)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    linear_regressor.train(\n        input_fn=lambda: ({\n            'age': ((17,),)\n        }, ((5.,),)), steps=num_steps)\n    self.assertEqual(1, mock_optimizer.minimize.call_count)\n    self._assert_checkpoint(\n        expected_global_step=initial_global_step + num_steps,\n        expected_age_weight=age_weight,\n        expected_bias=bias)\n\n  def testFromCheckpointMultiBatch(self):\n    # Create initial checkpoint.\n    age_weight = 10.0\n    bias = 5.0\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable([[age_weight]], name=AGE_WEIGHT_NAME)\n      tf.Variable([bias], name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # logits = age * age_weight + bias\n    # logits[0] = 17 * 10. + 5. = 175\n    # logits[1] = 15 * 10. + 5. 
= 155\n    # loss = sum(logits - label)^2 = (175 - 5)^2 + (155 - 3)^2 = 52004\n    mock_optimizer = self._mock_optimizer(expected_loss=52004.)\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        model_dir=self._model_dir,\n        optimizer=mock_optimizer)\n    self.assertEqual(0, mock_optimizer.minimize.call_count)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    linear_regressor.train(\n        input_fn=lambda: ({\n            'age': ((17,), (15,))\n        }, ((5.,), (3.,))),\n        steps=num_steps)\n    self.assertEqual(1, mock_optimizer.minimize.call_count)\n    self._assert_checkpoint(\n        expected_global_step=initial_global_step + num_steps,\n        expected_age_weight=age_weight,\n        expected_bias=bias)\n\n\nclass BaseLinearClassifierTrainingTest(object):\n\n  def __init__(self, linear_classifier_fn, fc_lib=feature_column):\n    self._linear_classifier_fn = linear_classifier_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _mock_optimizer(self, expected_loss=None):\n    expected_var_names = [\n        '%s/part_0:0' % AGE_WEIGHT_NAME,\n        '%s/part_0:0' % BIAS_NAME\n    ]\n\n    def _minimize(loss, global_step):\n      trainable_vars = tf.compat.v1.get_collection(\n          tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES)\n      self.assertItemsEqual(expected_var_names,\n                            [var.name for var in trainable_vars])\n\n      # Verify loss. 
We can't check the value directly, so we add an assert op.\n      self.assertEquals(0, loss.shape.ndims)\n      if expected_loss is None:\n        return tf.compat.v1.assign_add(global_step, 1).op\n      assert_loss = assert_close(\n          tf.cast(expected_loss, name='expected', dtype=tf.dtypes.float32),\n          loss,\n          name='assert_loss')\n      with tf.control_dependencies((assert_loss,)):\n        return tf.compat.v1.assign_add(global_step, 1).op\n\n    mock_optimizer = tf.compat.v1.test.mock.NonCallableMock(\n        spec=tf.compat.v1.train.Optimizer,\n        wraps=tf.compat.v1.train.Optimizer(\n            use_locking=False, name='my_optimizer'))\n    mock_optimizer.minimize = tf.compat.v1.test.mock.MagicMock(wraps=_minimize)\n\n    # NOTE: Estimator.params performs a deepcopy, which wreaks havoc with mocks.\n    # So, return mock_optimizer itself for deepcopy.\n    mock_optimizer.__deepcopy__ = lambda _: mock_optimizer\n    return mock_optimizer\n\n  def _assert_checkpoint(self,\n                         n_classes,\n                         expected_global_step,\n                         expected_age_weight=None,\n                         expected_bias=None):\n    logits_dimension = n_classes if n_classes > 2 else 1\n\n    shapes = {\n        name: shape\n        for (name, shape) in tf.train.list_variables(self._model_dir)\n    }\n\n    self.assertEqual([], shapes[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertEqual(\n        expected_global_step,\n        tf.train.load_variable(self._model_dir,\n                               tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n    self.assertEqual([1, logits_dimension], shapes[AGE_WEIGHT_NAME])\n    if expected_age_weight is not None:\n      self.assertAllEqual(\n          expected_age_weight,\n          tf.train.load_variable(self._model_dir, AGE_WEIGHT_NAME))\n\n    self.assertEqual([logits_dimension], shapes[BIAS_NAME])\n    if expected_bias is not None:\n      
self.assertAllEqual(expected_bias,\n                          tf.train.load_variable(self._model_dir, BIAS_NAME))\n\n  def _testFromScratchWithDefaultOptimizer(self, n_classes):\n    label = 0\n    age = 17\n    est = linear.LinearClassifier(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    # Train for a few steps, and validate final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self._assert_checkpoint(n_classes, num_steps)\n\n  def testBinaryClassesFromScratchWithDefaultOptimizer(self):\n    self._testFromScratchWithDefaultOptimizer(n_classes=2)\n\n  def testMultiClassesFromScratchWithDefaultOptimizer(self):\n    self._testFromScratchWithDefaultOptimizer(n_classes=4)\n\n  def _testTrainWithTwoDimsLabel(self, n_classes):\n    batch_size = 20\n\n    est = linear.LinearClassifier(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    data_rank_2 = np.array([[0], [1]])\n    self.assertEqual((2,), data_rank_1.shape)\n    self.assertEqual((2, 1), data_rank_2.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'age': data_rank_1},\n        y=data_rank_2,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithTwoDimsLabel(self):\n    self._testTrainWithTwoDimsLabel(n_classes=2)\n\n  def testMultiClassesTrainWithTwoDimsLabel(self):\n    self._testTrainWithTwoDimsLabel(n_classes=4)\n\n  def _testTrainWithOneDimLabel(self, n_classes):\n    batch_size = 20\n\n    est = linear.LinearClassifier(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        
model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    self.assertEqual((2,), data_rank_1.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'age': data_rank_1},\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithOneDimLabel(self):\n    self._testTrainWithOneDimLabel(n_classes=2)\n\n  def testMultiClassesTrainWithOneDimLabel(self):\n    self._testTrainWithOneDimLabel(n_classes=4)\n\n  def _testTrainWithTwoDimsWeight(self, n_classes):\n    batch_size = 20\n\n    est = linear.LinearClassifier(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        weight_column='w',\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    data_rank_2 = np.array([[0], [1]])\n    self.assertEqual((2,), data_rank_1.shape)\n    self.assertEqual((2, 1), data_rank_2.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={\n            'age': data_rank_1,\n            'w': data_rank_2\n        },\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithTwoDimsWeight(self):\n    self._testTrainWithTwoDimsWeight(n_classes=2)\n\n  def testMultiClassesTrainWithTwoDimsWeight(self):\n    self._testTrainWithTwoDimsWeight(n_classes=4)\n\n  def _testTrainWithOneDimWeight(self, n_classes):\n    batch_size = 20\n\n    est = linear.LinearClassifier(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        weight_column='w',\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n    data_rank_1 = np.array([0, 1])\n    self.assertEqual((2,), data_rank_1.shape)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={\n        
    'age': data_rank_1,\n            'w': data_rank_1\n        },\n        y=data_rank_1,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    est.train(train_input_fn, steps=200)\n    self._assert_checkpoint(n_classes, 200)\n\n  def testBinaryClassesTrainWithOneDimWeight(self):\n    self._testTrainWithOneDimWeight(n_classes=2)\n\n  def testMultiClassesTrainWithOneDimWeight(self):\n    self._testTrainWithOneDimWeight(n_classes=4)\n\n  def _testFromScratch(self, n_classes):\n    label = 1\n    age = 17\n    # For binary classifier:\n    #   loss = sigmoid_cross_entropy(logits, label) where logits=0 (weights are\n    #   all zero initially) and label = 1 so,\n    #      loss = 1 * -log ( sigmoid(logits) ) = 0.69315\n    # For multi class classifier:\n    #   loss = cross_entropy(logits, label) where logits are all 0s (weights are\n    #   all zero initially) and label = 1 so,\n    #      loss = 1 * -log ( 1.0 / n_classes )\n    # For this particular test case, as logits are same, the formular\n    # 1 * -log ( 1.0 / n_classes ) covers both binary and multi class cases.\n    mock_optimizer = self._mock_optimizer(\n        expected_loss=(-1 * math.log(1.0 / n_classes)))\n\n    est = linear.LinearClassifier(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        optimizer=mock_optimizer,\n        model_dir=self._model_dir)\n    self.assertEqual(0, mock_optimizer.minimize.call_count)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self.assertEqual(1, mock_optimizer.minimize.call_count)\n    self._assert_checkpoint(\n        n_classes,\n        expected_global_step=num_steps,\n        expected_age_weight=[[0.]] if n_classes == 2 else [[0.] * n_classes],\n        expected_bias=[0.] 
if n_classes == 2 else [.0] * n_classes)\n\n  def testBinaryClassesFromScratch(self):\n    self._testFromScratch(n_classes=2)\n\n  def testMultiClassesFromScratch(self):\n    self._testFromScratch(n_classes=4)\n\n  def _testFromCheckpoint(self, n_classes):\n    # Create initial checkpoint.\n    label = 1\n    age = 17\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    age_weight = [[2.0]] if n_classes == 2 else (np.reshape(\n        2.0 * np.array(list(range(n_classes)), dtype=np.float32),\n        (1, n_classes)))\n    bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(age_weight, name=AGE_WEIGHT_NAME)\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # For binary classifier:\n    #   logits = age * age_weight + bias = 17 * 2. - 35. 
= -1.\n    #   loss = sigmoid_cross_entropy(logits, label)\n    #   so, loss = 1 * -log ( sigmoid(-1) ) = 1.3133\n    # For multi class classifier:\n    #   loss = cross_entropy(logits, label)\n    #   where logits = 17 * age_weight + bias and label = 1\n    #   so, loss = 1 * -log ( soft_max(logits)[1] )\n    if n_classes == 2:\n      expected_loss = 1.3133\n    else:\n      logits = age_weight * age + bias\n      logits_exp = np.exp(logits)\n      softmax = logits_exp / logits_exp.sum()\n      expected_loss = -1 * math.log(softmax[0, label])\n\n    mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)\n\n    est = linear.LinearClassifier(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        optimizer=mock_optimizer,\n        model_dir=self._model_dir)\n    self.assertEqual(0, mock_optimizer.minimize.call_count)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self.assertEqual(1, mock_optimizer.minimize.call_count)\n    self._assert_checkpoint(\n        n_classes,\n        expected_global_step=initial_global_step + num_steps,\n        expected_age_weight=age_weight,\n        expected_bias=bias)\n\n  def testBinaryClassesFromCheckpoint(self):\n    self._testFromCheckpoint(n_classes=2)\n\n  def testMultiClassesFromCheckpoint(self):\n    self._testFromCheckpoint(n_classes=4)\n\n  def _testFromCheckpointFloatLabels(self, n_classes):\n    \"\"\"Tests float labels for binary classification.\"\"\"\n    # Create initial checkpoint.\n    if n_classes > 2:\n      return\n    label = 0.8\n    age = 17\n    age_weight = [[2.0]]\n    bias = [-35.0]\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(age_weight, name=AGE_WEIGHT_NAME)\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          
initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # logits = age * age_weight + bias = 17 * 2. - 35. = -1.\n    # loss = sigmoid_cross_entropy(logits, label)\n    # => loss = -0.8 * log(sigmoid(-1)) -0.2 * log(sigmoid(+1)) = 1.1132617\n    mock_optimizer = self._mock_optimizer(expected_loss=1.1132617)\n\n    est = linear.LinearClassifier(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        optimizer=mock_optimizer,\n        model_dir=self._model_dir)\n    self.assertEqual(0, mock_optimizer.minimize.call_count)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=num_steps)\n    self.assertEqual(1, mock_optimizer.minimize.call_count)\n\n  def testBinaryClassesFromCheckpointFloatLabels(self):\n    self._testFromCheckpointFloatLabels(n_classes=2)\n\n  def testMultiClassesFromCheckpointFloatLabels(self):\n    self._testFromCheckpointFloatLabels(n_classes=4)\n\n  def _testFromCheckpointMultiBatch(self, n_classes):\n    # Create initial checkpoint.\n    label = [1, 0]\n    age = [17.0, 18.5]\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). 
In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    age_weight = [[2.0]] if n_classes == 2 else (np.reshape(\n        2.0 * np.array(list(range(n_classes)), dtype=np.float32),\n        (1, n_classes)))\n    bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(age_weight, name=AGE_WEIGHT_NAME)\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    # For binary classifier:\n    #   logits = age * age_weight + bias\n    #   logits[0] = 17 * 2. - 35. = -1.\n    #   logits[1] = 18.5 * 2. - 35. = 2.\n    #   loss = sigmoid_cross_entropy(logits, label)\n    #   so, loss[0] = 1 * -log ( sigmoid(-1) ) = 1.3133\n    #       loss[1] = (1 - 0) * -log ( 1- sigmoid(2) ) = 2.1269\n    #   expected_loss = loss[0] + loss[1]\n    # For multi class classifier:\n    #   loss = cross_entropy(logits, label)\n    #   where logits = [17, 18.5] * age_weight + bias and label = [1, 0]\n    #   so, loss = 1 * -log ( soft_max(logits)[label] )\n    #   expected_loss = loss[0] + loss[1]\n    if n_classes == 2:\n      expected_loss = 1.3133 + 2.1269\n    else:\n      logits = age_weight * np.reshape(age, (2, 1)) + bias\n      logits_exp = np.exp(logits)\n      softmax_row_0 = logits_exp[0] / logits_exp[0].sum()\n      softmax_row_1 = logits_exp[1] / logits_exp[1].sum()\n      expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])\n      expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])\n      expected_loss = expected_loss_0 + expected_loss_1\n\n    mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)\n\n    est = linear.LinearClassifier(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        optimizer=mock_optimizer,\n        
model_dir=self._model_dir)\n    self.assertEqual(0, mock_optimizer.minimize.call_count)\n\n    # Train for a few steps, and validate optimizer and final checkpoint.\n    num_steps = 10\n    est.train(input_fn=lambda: ({'age': (age)}, (label)), steps=num_steps)\n    self.assertEqual(1, mock_optimizer.minimize.call_count)\n    self._assert_checkpoint(\n        n_classes,\n        expected_global_step=initial_global_step + num_steps,\n        expected_age_weight=age_weight,\n        expected_bias=bias)\n\n  def testBinaryClassesFromCheckpointMultiBatch(self):\n    self._testFromCheckpointMultiBatch(n_classes=2)\n\n  def testMultiClassesFromCheckpointMultiBatch(self):\n    self._testFromCheckpointMultiBatch(n_classes=4)\n\n\nclass BaseLinearClassifierEvaluationTest(object):\n\n  def __init__(self, linear_classifier_fn, fc_lib=feature_column):\n    self._linear_classifier_fn = linear_classifier_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _test_evaluation_for_simple_data(self, n_classes):\n    label = 1\n    age = 1.\n\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). 
In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    age_weight = [[-11.0]] if n_classes == 2 else (np.reshape(\n        -11.0 * np.array(list(range(n_classes)), dtype=np.float32),\n        (1, n_classes)))\n    bias = [-30.0] if n_classes == 2 else [-30.0] * n_classes\n\n    with tf.Graph().as_default():\n      tf.Variable(age_weight, name=AGE_WEIGHT_NAME)\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          100, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = self._linear_classifier_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n    eval_metrics = est.evaluate(\n        input_fn=lambda: ({\n            'age': ((age,),)\n        }, ((label,),)), steps=1)\n\n    if n_classes == 2:\n      # Binary classes: loss = sum(corss_entropy(41)) = 41.\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: 41.,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: 41.,\n          metric_keys.MetricKeys.ACCURACY: 0.,\n          metric_keys.MetricKeys.PRECISION: 0.,\n          metric_keys.MetricKeys.RECALL: 0.,\n          metric_keys.MetricKeys.PREDICTION_MEAN: 0.,\n          metric_keys.MetricKeys.LABEL_MEAN: 1.,\n          metric_keys.MetricKeys.ACCURACY_BASELINE: 1,\n          metric_keys.MetricKeys.AUC: 0.,\n          metric_keys.MetricKeys.AUC_PR: 1.,\n      }\n    else:\n      # Multi classes: loss = 1 * -log ( soft_max(logits)[label] )\n      logits = age_weight * age + bias\n      logits_exp = np.exp(logits)\n      softmax = logits_exp / logits_exp.sum()\n      expected_loss = -1 * math.log(softmax[0, label])\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          metric_keys.MetricKeys.LOSS_MEAN: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n      
    metric_keys.MetricKeys.ACCURACY: 0.,\n      }\n\n    self.assertAllClose(\n        sorted_key_dict(expected_metrics),\n        sorted_key_dict(eval_metrics),\n        rtol=1e-3)\n\n  def test_binary_classes_evaluation_for_simple_data(self):\n    self._test_evaluation_for_simple_data(n_classes=2)\n\n  def test_multi_classes_evaluation_for_simple_data(self):\n    self._test_evaluation_for_simple_data(n_classes=4)\n\n  def _test_evaluation_batch(self, n_classes):\n    \"\"\"Tests evaluation for batch_size==2.\"\"\"\n    label = [1, 0]\n    age = [17., 18.]\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    age_weight = [[2.0]] if n_classes == 2 else (np.reshape(\n        2.0 * np.array(list(range(n_classes)), dtype=np.float32),\n        (1, n_classes)))\n    bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(age_weight, name=AGE_WEIGHT_NAME)\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = self._linear_classifier_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n    eval_metrics = est.evaluate(\n        input_fn=lambda: ({\n            'age': (age)\n        }, (label)), steps=1)\n\n    if n_classes == 2:\n      # Logits are (-1., 1.) 
labels are (1, 0).\n      # Loss is\n      #   loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133\n      #   loss for row 2: (1 - 0) * -log(1 - sigmoid(1)) = 1.3133\n      expected_loss = 1.3133 * 2\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,\n          metric_keys.MetricKeys.ACCURACY: 0.,\n          metric_keys.MetricKeys.PRECISION: 0.,\n          metric_keys.MetricKeys.RECALL: 0.,\n          metric_keys.MetricKeys.PREDICTION_MEAN: 0.5,\n          metric_keys.MetricKeys.LABEL_MEAN: 0.5,\n          metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5,\n          metric_keys.MetricKeys.AUC: 0.,\n          metric_keys.MetricKeys.AUC_PR: 0.25,\n      }\n    else:\n      # Multi classes: loss = 1 * -log ( soft_max(logits)[label] )\n      logits = age_weight * np.reshape(age, (2, 1)) + bias\n      logits_exp = np.exp(logits)\n      softmax_row_0 = logits_exp[0] / logits_exp[0].sum()\n      softmax_row_1 = logits_exp[1] / logits_exp[1].sum()\n      expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])\n      expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])\n      expected_loss = expected_loss_0 + expected_loss_1\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.ACCURACY: 0.,\n      }\n\n    self.assertAllClose(\n        sorted_key_dict(expected_metrics),\n        sorted_key_dict(eval_metrics),\n        rtol=1e-3)\n\n  def test_binary_classes_evaluation_batch(self):\n    self._test_evaluation_batch(n_classes=2)\n\n  def test_multi_classes_evaluation_batch(self):\n    self._test_evaluation_batch(n_classes=4)\n\n  def _test_evaluation_weights(self, n_classes):\n    \"\"\"Tests evaluation with weights.\"\"\"\n\n    label = [1, 0]\n    age 
= [17., 18.]\n    weights = [1., 2.]\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    age_weight = [[2.0]] if n_classes == 2 else (np.reshape(\n        2.0 * np.array(list(range(n_classes)), dtype=np.float32),\n        (1, n_classes)))\n    bias = [-35.0] if n_classes == 2 else [-35.0] * n_classes\n    initial_global_step = 100\n    with tf.Graph().as_default():\n      tf.Variable(age_weight, name=AGE_WEIGHT_NAME)\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(\n          initial_global_step,\n          name=tf.compat.v1.GraphKeys.GLOBAL_STEP,\n          dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = self._linear_classifier_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        n_classes=n_classes,\n        weight_column='w',\n        model_dir=self._model_dir)\n    eval_metrics = est.evaluate(\n        input_fn=lambda: ({\n            'age': (age),\n            'w': (weights)\n        }, (label)), steps=1)\n\n    if n_classes == 2:\n      # Logits are (-1., 1.) labels are (1, 0).\n      # Loss is\n      #   loss for row 1: 1 * -log(sigmoid(-1)) = 1.3133\n      #   loss for row 2: (1 - 0) * -log(1 - sigmoid(1)) = 1.3133\n      #   weights = [1., 2.]\n      expected_loss = 1.3133 * (1. 
+ 2.)\n      loss_mean = expected_loss / (1.0 + 2.0)\n      label_mean = np.average(label, weights=weights)\n      logits = [-1, 1]\n      logistics = sigmoid(np.array(logits))\n      predictions_mean = np.average(logistics, weights=weights)\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.LOSS_MEAN: loss_mean,\n          metric_keys.MetricKeys.ACCURACY: 0.,\n          metric_keys.MetricKeys.PRECISION: 0.,\n          metric_keys.MetricKeys.RECALL: 0.,\n          metric_keys.MetricKeys.PREDICTION_MEAN: predictions_mean,\n          metric_keys.MetricKeys.LABEL_MEAN: label_mean,\n          metric_keys.MetricKeys.ACCURACY_BASELINE:\n              (max(label_mean, 1 - label_mean)),\n          metric_keys.MetricKeys.AUC: 0.,\n          metric_keys.MetricKeys.AUC_PR: 0.1668,\n      }\n    else:\n      # Multi classes: unweighted_loss = 1 * -log ( soft_max(logits)[label] )\n      logits = age_weight * np.reshape(age, (2, 1)) + bias\n      logits_exp = np.exp(logits)\n      softmax_row_0 = logits_exp[0] / logits_exp[0].sum()\n      softmax_row_1 = logits_exp[1] / logits_exp[1].sum()\n      expected_loss_0 = -1 * math.log(softmax_row_0[label[0]])\n      expected_loss_1 = -1 * math.log(softmax_row_1[label[1]])\n      loss_mean = np.average([expected_loss_0, expected_loss_1],\n                             weights=weights)\n      expected_loss = loss_mean * np.sum(weights)\n\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n          metric_keys.MetricKeys.LOSS_MEAN: loss_mean,\n          tf.compat.v1.GraphKeys.GLOBAL_STEP: 100,\n          metric_keys.MetricKeys.ACCURACY: 0.,\n      }\n\n    self.assertAllClose(\n        sorted_key_dict(expected_metrics),\n        sorted_key_dict(eval_metrics),\n        rtol=1e-3)\n\n  def test_binary_classes_evaluation_weights(self):\n    self._test_evaluation_weights(n_classes=2)\n\n  def 
test_multi_classes_evaluation_weights(self):\n    self._test_evaluation_weights(n_classes=4)\n\n\nclass BaseLinearClassifierPredictTest(object):\n\n  def __init__(self, linear_classifier_fn, fc_lib=feature_column):\n    self._linear_classifier_fn = linear_classifier_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _testPredictions(self, n_classes, label_vocabulary, label_output_fn):\n    \"\"\"Tests predict when all variables are one-dimensional.\"\"\"\n    age = 1.\n\n    # For binary case, the expected weight has shape (1,1). For multi class\n    # case, the shape is (1, n_classes). In order to test the weights, set\n    # weights as 2.0 * range(n_classes).\n    age_weight = [[-11.0]] if n_classes == 2 else (np.reshape(\n        -11.0 * np.array(list(range(n_classes)), dtype=np.float32),\n        (1, n_classes)))\n    bias = [10.0] if n_classes == 2 else [10.0] * n_classes\n\n    with tf.Graph().as_default():\n      tf.Variable(age_weight, name=AGE_WEIGHT_NAME)\n      tf.Variable(bias, name=BIAS_NAME)\n      tf.Variable(100, name='global_step', dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    est = self._linear_classifier_fn(\n        feature_columns=(self._fc_lib.numeric_column('age'),),\n        label_vocabulary=label_vocabulary,\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'age': np.array([[age]])},\n        y=None,\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    predictions = list(est.predict(input_fn=predict_input_fn))\n\n    if n_classes == 2:\n      scalar_logits = np.reshape(np.array(age_weight) * age + bias, (1,)).item()\n      two_classes_logits = [0, scalar_logits]\n      two_classes_logits_exp = np.exp(two_classes_logits)\n      softmax = two_classes_logits_exp / 
two_classes_logits_exp.sum()\n\n      expected_predictions = {\n          'class_ids': [0],\n          'all_class_ids': [0, 1],\n          'classes': [label_output_fn(0)],\n          'all_classes': [label_output_fn(0),\n                          label_output_fn(1)],\n          'logistic': [sigmoid(np.array(scalar_logits))],\n          'logits': [scalar_logits],\n          'probabilities': softmax,\n      }\n    else:\n      onedim_logits = np.reshape(np.array(age_weight) * age + bias, (-1,))\n      class_ids = onedim_logits.argmax()\n      all_class_ids = list(range(len(onedim_logits)))\n      logits_exp = np.exp(onedim_logits)\n      softmax = logits_exp / logits_exp.sum()\n      expected_predictions = {\n          'class_ids': [class_ids],\n          'all_class_ids': all_class_ids,\n          'classes': [label_output_fn(class_ids)],\n          'all_classes': [label_output_fn(i) for i in all_class_ids],\n          'logits': onedim_logits,\n          'probabilities': softmax,\n      }\n\n    self.assertEqual(1, len(predictions))\n    # assertAllClose cannot handle byte type.\n    self.assertEqual(expected_predictions['classes'], predictions[0]['classes'])\n    expected_predictions.pop('classes')\n    predictions[0].pop('classes')\n    self.assertAllEqual(expected_predictions['all_classes'],\n                        predictions[0]['all_classes'])\n    expected_predictions.pop('all_classes')\n    predictions[0].pop('all_classes')\n    self.assertAllClose(\n        sorted_key_dict(expected_predictions), sorted_key_dict(predictions[0]))\n\n  def testBinaryClassesWithoutLabelVocabulary(self):\n    n_classes = 2\n    self._testPredictions(\n        n_classes,\n        label_vocabulary=None,\n        label_output_fn=lambda x: ('%s' % x).encode())\n\n  def testBinaryClassesWithLabelVocabulary(self):\n    n_classes = 2\n    self._testPredictions(\n        n_classes,\n        label_vocabulary=['class_vocab_{}'.format(i) for i in range(n_classes)],\n        
label_output_fn=lambda x: ('class_vocab_%s' % x).encode())\n\n  def testMultiClassesWithoutLabelVocabulary(self):\n    n_classes = 4\n    self._testPredictions(\n        n_classes,\n        label_vocabulary=None,\n        label_output_fn=lambda x: ('%s' % x).encode())\n\n  def testMultiClassesWithLabelVocabulary(self):\n    n_classes = 4\n    self._testPredictions(\n        n_classes,\n        label_vocabulary=['class_vocab_{}'.format(i) for i in range(n_classes)],\n        label_output_fn=lambda x: ('class_vocab_%s' % x).encode())\n\n  def testSparseCombiner(self):\n    w_a = 2.0\n    w_b = 3.0\n    w_c = 5.0\n    bias = 5.0\n    with tf.Graph().as_default():\n      tf.Variable([[w_a], [w_b], [w_c]], name=LANGUAGE_WEIGHT_NAME)\n      tf.Variable([bias], name=BIAS_NAME)\n      tf.Variable(\n          1, name=tf.compat.v1.GraphKeys.GLOBAL_STEP, dtype=tf.dtypes.int64)\n      save_variables_to_ckpt(self._model_dir)\n\n    def _input_fn():\n      return tf.compat.v1.data.Dataset.from_tensors({\n          'language':\n              tf.sparse.SparseTensor(\n                  values=['a', 'c', 'b', 'c'],\n                  indices=[[0, 0], [0, 1], [1, 0], [1, 1]],\n                  dense_shape=[2, 2]),\n      })\n\n    feature_columns = (self._fc_lib.categorical_column_with_vocabulary_list(\n        'language', vocabulary_list=['a', 'b', 'c']),)\n\n    # Check prediction for each sparse_combiner.\n    # With sparse_combiner = 'sum', we have\n    # logits_1 = w_a + w_c + bias\n    #          = 2.0 + 5.0 + 5.0 = 12.0\n    # logits_2 = w_b + w_c + bias\n    #          = 3.0 + 5.0 + 5.0 = 13.0\n    linear_classifier = self._linear_classifier_fn(\n        feature_columns=feature_columns, model_dir=self._model_dir)\n    predictions = linear_classifier.predict(input_fn=_input_fn)\n    predicted_scores = list([x['logits'] for x in predictions])\n    self.assertAllClose([[12.0], [13.0]], predicted_scores)\n\n    # With sparse_combiner = 'mean', we have\n    # logits_1 = 1/2 * 
(w_a + w_c) + bias\n    #          = 1/2 * (2.0 + 5.0) + 5.0 = 8.5\n    # logits_2 = 1/2 * (w_b + w_c) + bias\n    #          = 1/2 * (3.0 + 5.0) + 5.0 = 9.0\n    linear_classifier = self._linear_classifier_fn(\n        feature_columns=feature_columns,\n        model_dir=self._model_dir,\n        sparse_combiner='mean')\n    predictions = linear_classifier.predict(input_fn=_input_fn)\n    predicted_scores = list([x['logits'] for x in predictions])\n    self.assertAllClose([[8.5], [9.0]], predicted_scores)\n\n    # With sparse_combiner = 'sqrtn', we have\n    # logits_1 = sqrt(2)/2 * (w_a + w_c) + bias\n    #          = sqrt(2)/2 * (2.0 + 5.0) + 5.0 = 9.94974\n    # logits_2 = sqrt(2)/2 * (w_b + w_c) + bias\n    #          = sqrt(2)/2 * (3.0 + 5.0) + 5.0 = 10.65685\n    linear_classifier = self._linear_classifier_fn(\n        feature_columns=feature_columns,\n        model_dir=self._model_dir,\n        sparse_combiner='sqrtn')\n    predictions = linear_classifier.predict(input_fn=_input_fn)\n    predicted_scores = list([x['logits'] for x in predictions])\n    self.assertAllClose([[9.94974], [10.65685]], predicted_scores)\n\n\nclass BaseLinearClassifierIntegrationTest(object):\n\n  def __init__(self, linear_classifier_fn, fc_lib=feature_column):\n    self._linear_classifier_fn = linear_classifier_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _test_complete_flow(self, n_classes, train_input_fn, eval_input_fn,\n                          predict_input_fn, input_dimension, prediction_length):\n    feature_columns = [\n        self._fc_lib.numeric_column('x', shape=(input_dimension,))\n    ]\n    est = self._linear_classifier_fn(\n        feature_columns=feature_columns,\n        n_classes=n_classes,\n        model_dir=self._model_dir)\n\n    # TRAIN\n    # learn y = x\n    est.train(train_input_fn, steps=200)\n\n    # EVALUTE\n  
  scores = est.evaluate(eval_input_fn)\n    self.assertEqual(200, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn(metric_keys.MetricKeys.LOSS, six.iterkeys(scores))\n\n    # PREDICT\n    predictions = np.array(\n        [x['classes'] for x in est.predict(predict_input_fn)])\n    self.assertAllEqual((prediction_length, 1), predictions.shape)\n\n    # EXPORT\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def _test_numpy_input_fn(self, n_classes):\n    \"\"\"Tests complete flow with numpy_input_fn.\"\"\"\n    input_dimension = 4\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, input_dimension)\n    target = np.array([1] * batch_size)\n\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=target,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=target,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=None,\n        batch_size=batch_size,\n        num_epochs=1,\n        shuffle=False)\n\n    self._test_complete_flow(\n        n_classes=n_classes,\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        prediction_length=prediction_length)\n\n  def test_binary_classes_numpy_input_fn(self):\n    
self._test_numpy_input_fn(n_classes=2)\n\n  def test_multi_classes_numpy_input_fn(self):\n    self._test_numpy_input_fn(n_classes=4)\n\n  def _test_pandas_input_fn(self, n_classes):\n    \"\"\"Tests complete flow with pandas_input_fn.\"\"\"\n    if not HAS_PANDAS:\n      return\n\n    # Pandas DataFrame natually supports 1 dim data only.\n    input_dimension = 1\n    batch_size = 10\n    data = np.array([1., 2., 3., 4.], dtype=np.float32)\n    target = np.array([1, 0, 1, 0], dtype=np.int32)\n    x = pd.DataFrame({'x': data})\n    y = pd.Series(target)\n    prediction_length = 4\n\n    train_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, num_epochs=None, shuffle=True)\n    eval_input_fn = pandas_io.pandas_input_fn(\n        x=x, y=y, batch_size=batch_size, shuffle=False)\n    predict_input_fn = pandas_io.pandas_input_fn(\n        x=x, batch_size=batch_size, shuffle=False)\n\n    self._test_complete_flow(\n        n_classes=n_classes,\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=input_dimension,\n        prediction_length=prediction_length)\n\n  def test_binary_classes_pandas_input_fn(self):\n    self._test_pandas_input_fn(n_classes=2)\n\n  def test_multi_classes_pandas_input_fn(self):\n    self._test_pandas_input_fn(n_classes=4)\n\n  def _test_input_fn_from_parse_example(self, n_classes):\n    \"\"\"Tests complete flow with input_fn constructed from parse_example.\"\"\"\n    input_dimension = 2\n    batch_size = 10\n    prediction_length = batch_size\n    data = np.linspace(0., 2., batch_size * input_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, input_dimension)\n    target = np.array([1] * batch_size, dtype=np.int64)\n\n    serialized_examples = []\n    for x, y in zip(data, target):\n      example = example_pb2.Example(\n          features=feature_pb2.Features(\n              feature={\n                  'x':\n  
                    feature_pb2.Feature(\n                          float_list=feature_pb2.FloatList(value=x)),\n                  'y':\n                      feature_pb2.Feature(\n                          int64_list=feature_pb2.Int64List(value=[y])),\n              }))\n      serialized_examples.append(example.SerializeToString())\n\n    feature_spec = {\n        'x': tf.io.FixedLenFeature([input_dimension], tf.dtypes.float32),\n        'y': tf.io.FixedLenFeature([1], tf.dtypes.int64),\n    }\n\n    def _train_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(serialized_examples,\n                                                  feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _eval_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = queue_parsed_features(feature_map)\n      labels = features.pop('y')\n      return features, labels\n\n    def _predict_input_fn():\n      feature_map = tf.compat.v1.io.parse_example(\n          tf.compat.v1.train.limit_epochs(serialized_examples, num_epochs=1),\n          feature_spec)\n      features = queue_parsed_features(feature_map)\n      features.pop('y')\n      return features, None\n\n    self._test_complete_flow(\n        n_classes=n_classes,\n        train_input_fn=_train_input_fn,\n        eval_input_fn=_eval_input_fn,\n        predict_input_fn=_predict_input_fn,\n        input_dimension=input_dimension,\n        prediction_length=prediction_length)\n\n  def test_binary_classes_input_fn_from_parse_example(self):\n    self._test_input_fn_from_parse_example(n_classes=2)\n\n  def test_multi_classes_input_fn_from_parse_example(self):\n    self._test_input_fn_from_parse_example(n_classes=4)\n\n\nclass BaseLinearLogitFnTest(object):\n\n  def __init__(self, fc_lib=feature_column):\n    
self._fc_lib = fc_lib\n\n  def test_basic_logit_correctness(self):\n    \"\"\"linear_logit_fn simply wraps feature_column_lib.linear_model.\"\"\"\n    age = self._fc_lib.numeric_column('age')\n    with tf.Graph().as_default():\n      logit_fn = linear.linear_logit_fn_builder(units=2, feature_columns=[age])\n      logits = logit_fn(features={'age': [[23.], [31.]]})\n      bias_var = tf.compat.v1.get_collection(\n          tf.compat.v1.GraphKeys.GLOBAL_VARIABLES,\n          'linear_model/bias_weights')[0]\n      age_var = tf.compat.v1.get_collection(\n          tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, 'linear_model/age')[0]\n      with tf.compat.v1.Session() as sess:\n        sess.run([tf.compat.v1.initializers.global_variables()])\n        self.assertAllClose([[0., 0.], [0., 0.]], logits.eval())\n        sess.run(bias_var.assign([10., 5.]))\n        self.assertAllClose([[10., 5.], [10., 5.]], logits.eval())\n        sess.run(age_var.assign([[2.0, 3.0]]))\n        # [2 * 23 + 10, 3 * 23 + 5] = [56, 74].\n        # [2 * 31 + 10, 3 * 31 + 5] = [72, 98]\n        self.assertAllClose([[56., 74.], [72., 98.]], logits.eval())\n\n  def test_compute_fraction_of_zero(self):\n    \"\"\"Tests the calculation of sparsity.\"\"\"\n    if self._fc_lib != feature_column:\n      return\n    age = tf.feature_column.numeric_column('age')\n    occupation = feature_column.categorical_column_with_hash_bucket(\n        'occupation', hash_bucket_size=5)\n    with tf.Graph().as_default():\n      cols_to_vars = {}\n      tf.compat.v1.feature_column.linear_model(\n          features={\n              'age': [[23.], [31.]],\n              'occupation': [['doctor'], ['engineer']]\n          },\n          feature_columns=[age, occupation],\n          units=3,\n          cols_to_vars=cols_to_vars)\n      cols_to_vars.pop('bias')\n      fraction_zero = linear._compute_fraction_of_zero(\n          list(cols_to_vars.values()))\n      age_var = tf.compat.v1.get_collection(\n          
tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, 'linear_model/age')[0]\n      with tf.compat.v1.Session() as sess:\n        sess.run([tf.compat.v1.initializers.global_variables()])\n        # Upon initialization, all variables will be zero.\n        self.assertAllClose(1, fraction_zero.eval())\n\n        sess.run(age_var.assign([[2.0, 0.0, -1.0]]))\n        # 1 of the 3 age weights are zero, and all of the 15 (5 hash buckets\n        # x 3-dim output) are zero.\n        self.assertAllClose(16. / 18., fraction_zero.eval())\n\n  def test_compute_fraction_of_zero_v2(self):\n    \"\"\"Tests the calculation of sparsity.\"\"\"\n    if self._fc_lib != feature_column_v2:\n      return\n\n    age = tf.feature_column.numeric_column('age')\n    occupation = tf.feature_column.categorical_column_with_hash_bucket(\n        'occupation', hash_bucket_size=5)\n    with tf.Graph().as_default():\n      model = feature_column_v2.LinearModel(\n          feature_columns=[age, occupation], units=3, name='linear_model')\n      features = {\n          'age': [[23.], [31.]],\n          'occupation': [['doctor'], ['engineer']]\n      }\n      model(features)\n      variables = model.variables\n      variables.remove(model.bias)\n      fraction_zero = linear._compute_fraction_of_zero(variables)\n      age_var = tf.compat.v1.get_collection(\n          tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, 'linear_model/age')[0]\n      with tf.compat.v1.Session() as sess:\n        sess.run([tf.compat.v1.initializers.global_variables()])\n        # Upon initialization, all variables will be zero.\n        self.assertAllClose(1, fraction_zero.eval())\n\n        sess.run(age_var.assign([[2.0, 0.0, -1.0]]))\n        # 1 of the 3 age weights are zero, and all of the 15 (5 hash buckets\n        # x 3-dim output) are zero.\n        self.assertAllClose(16. 
/ 18., fraction_zero.eval())\n\n\nclass BaseLinearWarmStartingTest(object):\n\n  def __init__(self,\n               _linear_classifier_fn,\n               _linear_regressor_fn,\n               fc_lib=feature_column):\n    self._linear_classifier_fn = _linear_classifier_fn\n    self._linear_regressor_fn = _linear_regressor_fn\n    self._fc_lib = fc_lib\n\n  def setUp(self):\n    # Create a directory to save our old checkpoint and vocabularies to.\n    self._ckpt_and_vocab_dir = tempfile.mkdtemp()\n\n    # Make a dummy input_fn.\n    def _input_fn():\n      features = {\n          'age': [[23.], [31.]],\n          'age_in_years': [[23.], [31.]],\n          'occupation': [['doctor'], ['consultant']]\n      }\n      return features, [0, 1]\n\n    self._input_fn = _input_fn\n\n  def tearDown(self):\n    # Clean up checkpoint / vocab dir.\n    tf.compat.v1.summary.FileWriterCache.clear()\n    shutil.rmtree(self._ckpt_and_vocab_dir)\n\n  def test_classifier_basic_warm_starting(self):\n    \"\"\"Tests correctness of LinearClassifier default warm-start.\"\"\"\n    age = self._fc_lib.numeric_column('age')\n\n    # Create a LinearClassifier and train to save a checkpoint.\n    linear_classifier = self._linear_classifier_fn(\n        feature_columns=[age],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD')\n    linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second LinearClassifier, warm-started from the first.  
Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    warm_started_linear_classifier = self._linear_classifier_fn(\n        feature_columns=[age],\n        n_classes=4,\n        optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n            learning_rate=0.0),\n        warm_start_from=linear_classifier.model_dir)\n\n    warm_started_linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_linear_classifier.get_variable_names():\n      self.assertAllClose(\n          linear_classifier.get_variable_value(variable_name),\n          warm_started_linear_classifier.get_variable_value(variable_name))\n\n  def test_regressor_basic_warm_starting(self):\n    \"\"\"Tests correctness of LinearRegressor default warm-start.\"\"\"\n    age = self._fc_lib.numeric_column('age')\n\n    # Create a LinearRegressor and train to save a checkpoint.\n    linear_regressor = self._linear_regressor_fn(\n        feature_columns=[age],\n        model_dir=self._ckpt_and_vocab_dir,\n        optimizer='SGD')\n    linear_regressor.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second LinearRegressor, warm-started from the first.  
Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    warm_started_linear_regressor = self._linear_regressor_fn(\n        feature_columns=[age],\n        optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n            learning_rate=0.0),\n        warm_start_from=linear_regressor.model_dir)\n\n    warm_started_linear_regressor.train(input_fn=self._input_fn, max_steps=1)\n    for variable_name in warm_started_linear_regressor.get_variable_names():\n      self.assertAllClose(\n          linear_regressor.get_variable_value(variable_name),\n          warm_started_linear_regressor.get_variable_value(variable_name))\n\n  def test_warm_starting_selective_variables(self):\n    \"\"\"Tests selecting variables to warm-start.\"\"\"\n    age = self._fc_lib.numeric_column('age')\n\n    # Create a LinearClassifier and train to save a checkpoint.\n    linear_classifier = self._linear_classifier_fn(\n        feature_columns=[age],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD')\n    linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second LinearClassifier, warm-started from the first.  
Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    warm_started_linear_classifier = self._linear_classifier_fn(\n        feature_columns=[age],\n        n_classes=4,\n        optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n            learning_rate=0.0),\n        # The provided regular expression will only warm-start the age variable\n        # and not the bias.\n        warm_start_from=estimator.WarmStartSettings(\n            ckpt_to_initialize_from=linear_classifier.model_dir,\n            vars_to_warm_start='.*(age).*'))\n\n    warm_started_linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n    self.assertAllClose(\n        linear_classifier.get_variable_value(AGE_WEIGHT_NAME),\n        warm_started_linear_classifier.get_variable_value(AGE_WEIGHT_NAME))\n    # Bias should still be zero from initialization.\n    self.assertAllClose(\n        [0.0] * 4, warm_started_linear_classifier.get_variable_value(BIAS_NAME))\n\n  def test_warm_starting_with_vocab_remapping_and_partitioning(self):\n    \"\"\"Tests warm-starting with vocab remapping and partitioning.\"\"\"\n    vocab_list = ['doctor', 'lawyer', 'consultant']\n    vocab_file = os.path.join(self._ckpt_and_vocab_dir, 'occupation_vocab')\n    with open(vocab_file, 'w') as f:\n      f.write('\\n'.join(vocab_list))\n    occupation = self._fc_lib.categorical_column_with_vocabulary_file(\n        'occupation',\n        vocabulary_file=vocab_file,\n        vocabulary_size=len(vocab_list))\n\n    # Create a LinearClassifier and train to save a checkpoint.\n    partitioner = tf.compat.v1.fixed_size_partitioner(num_shards=2)\n    linear_classifier = self._linear_classifier_fn(\n        feature_columns=[occupation],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD',\n        partitioner=partitioner)\n    linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a 
second LinearClassifier, warm-started from the first.  Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).  Use a new FeatureColumn with a\n    # different vocabulary for occupation.\n    new_vocab_list = ['doctor', 'consultant', 'engineer']\n    new_vocab_file = os.path.join(self._ckpt_and_vocab_dir,\n                                  'new_occupation_vocab')\n    with open(new_vocab_file, 'w') as f:\n      f.write('\\n'.join(new_vocab_list))\n    new_occupation = self._fc_lib.categorical_column_with_vocabulary_file(\n        'occupation',\n        vocabulary_file=new_vocab_file,\n        vocabulary_size=len(new_vocab_list))\n    # We can create our VocabInfo object from the new and old occupation\n    # FeatureColumn's.\n    occupation_vocab_info = estimator.VocabInfo(\n        new_vocab=new_occupation.vocabulary_file,\n        new_vocab_size=new_occupation.vocabulary_size,\n        num_oov_buckets=new_occupation.num_oov_buckets,\n        old_vocab=occupation.vocabulary_file,\n        old_vocab_size=occupation.vocabulary_size,\n        # Can't use constant_initializer with load_and_remap.  
In practice,\n        # use a truncated normal initializer.\n        backup_initializer=tf.compat.v1.initializers.random_uniform(\n            minval=0.39, maxval=0.39))\n    warm_started_linear_classifier = self._linear_classifier_fn(\n        feature_columns=[occupation],\n        n_classes=4,\n        optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n            learning_rate=0.0),\n        warm_start_from=estimator.WarmStartSettings(\n            ckpt_to_initialize_from=linear_classifier.model_dir,\n            var_name_to_vocab_info={\n                OCCUPATION_WEIGHT_NAME: occupation_vocab_info\n            },\n            # Explicitly providing None here will only warm-start variables\n            # referenced in var_name_to_vocab_info (the bias will not be\n            # warm-started).\n            vars_to_warm_start=None),\n        partitioner=partitioner)\n\n    warm_started_linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n    # 'doctor' was ID-0 and still ID-0.\n    self.assertAllClose(\n        linear_classifier.get_variable_value(OCCUPATION_WEIGHT_NAME)[0, :],\n        warm_started_linear_classifier.get_variable_value(\n            OCCUPATION_WEIGHT_NAME)[0, :])\n    # 'consultant' was ID-2 and now ID-1.\n    self.assertAllClose(\n        linear_classifier.get_variable_value(OCCUPATION_WEIGHT_NAME)[2, :],\n        warm_started_linear_classifier.get_variable_value(\n            OCCUPATION_WEIGHT_NAME)[1, :])\n    # 'engineer' is a new entry and should be initialized with the\n    # backup_initializer in VocabInfo.\n    self.assertAllClose([0.39] * 4,\n                        warm_started_linear_classifier.get_variable_value(\n                            OCCUPATION_WEIGHT_NAME)[2, :])\n    # Bias should still be zero (from initialization logic).\n    self.assertAllClose(\n        [0.0] * 4, warm_started_linear_classifier.get_variable_value(BIAS_NAME))\n\n  def test_warm_starting_with_naming_change(self):\n    \"\"\"Tests 
warm-starting with a Tensor name remapping.\"\"\"\n    age_in_years = self._fc_lib.numeric_column('age_in_years')\n\n    # Create a LinearClassifier and train to save a checkpoint.\n    linear_classifier = self._linear_classifier_fn(\n        feature_columns=[age_in_years],\n        model_dir=self._ckpt_and_vocab_dir,\n        n_classes=4,\n        optimizer='SGD')\n    linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n\n    # Create a second LinearClassifier, warm-started from the first.  Use a\n    # learning_rate = 0.0 optimizer to check values (use SGD so we don't have\n    # accumulator values that change).\n    warm_started_linear_classifier = self._linear_classifier_fn(\n        feature_columns=[self._fc_lib.numeric_column('age')],\n        n_classes=4,\n        optimizer=tf.compat.v1.train.GradientDescentOptimizer(\n            learning_rate=0.0),\n        # The 'age' variable correspond to the 'age_in_years' variable in the\n        # previous model.\n        warm_start_from=estimator.WarmStartSettings(\n            ckpt_to_initialize_from=linear_classifier.model_dir,\n            var_name_to_prev_var_name={\n                AGE_WEIGHT_NAME: AGE_WEIGHT_NAME.replace('age', 'age_in_years')\n            }))\n\n    warm_started_linear_classifier.train(input_fn=self._input_fn, max_steps=1)\n    self.assertAllClose(\n        linear_classifier.get_variable_value(\n            AGE_WEIGHT_NAME.replace('age', 'age_in_years')),\n        warm_started_linear_classifier.get_variable_value(AGE_WEIGHT_NAME))\n    # The bias is also warm-started (with no name remapping).\n    self.assertAllClose(\n        linear_classifier.get_variable_value(BIAS_NAME),\n        warm_started_linear_classifier.get_variable_value(BIAS_NAME))\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/distribute_strategy_estimator_integration_test.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests that show that DistributionStrategy works with canned Estimator.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport shutil\nimport tempfile\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator import estimator as estimator_lib\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator import run_config\nfrom tensorflow_estimator.python.estimator import training\nfrom tensorflow_estimator.python.estimator.canned import dnn_linear_combined\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.export import export_lib as export\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\n\n\nclass DNNLinearCombinedClassifierIntegrationTest(tf.test.TestCase,\n                                                 parameterized.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def dataset_input_fn(self, x, y, batch_size, shuffle):\n\n    def input_fn():\n      dataset = tf.compat.v1.data.Dataset.from_tensor_slices((x, y))\n      if shuffle:\n   
     dataset = dataset.shuffle(batch_size)\n      dataset = dataset.repeat(10).batch(batch_size)\n      return dataset\n\n    return input_fn\n\n  @tf.compat.v2.__internal__.distribute.combinations.generate(\n      tf.compat.v2.__internal__.test.combinations.combine(\n          mode=['graph'],\n          distribution=[\n              tf.compat.v2.__internal__.distribute.combinations.one_device_strategy,\n              tf.compat.v2.__internal__.distribute.combinations.mirrored_strategy_with_gpu_and_cpu,\n              tf.compat.v2.__internal__.distribute.combinations.mirrored_strategy_with_two_gpus\n          ],\n          use_train_and_evaluate=[True, False]))\n  def test_estimator_with_strategy_hooks(self, distribution,\n                                         use_train_and_evaluate):\n    config = run_config.RunConfig(eval_distribute=distribution)\n\n    def _input_map_fn(tensor):\n      return {'feature': tensor}, tensor\n\n    def input_fn():\n      return tf.data.Dataset.from_tensors(\n          [1.]).repeat(10).batch(5).map(_input_map_fn)\n\n    def model_fn(features, labels, mode):\n      del features, labels\n      global_step = tf.compat.v1.train.get_global_step()\n      if mode == model_fn_lib.ModeKeys.TRAIN:\n        train_hook1 = tf.compat.v1.train.StepCounterHook(\n            every_n_steps=1, output_dir=self.get_temp_dir())\n        train_hook2 = tf.compat.v1.test.mock.MagicMock(\n            wraps=tf.compat.v1.train.SessionRunHook(),\n            spec=tf.compat.v1.train.SessionRunHook)\n        return model_fn_lib.EstimatorSpec(\n            mode,\n            loss=tf.constant(1.),\n            train_op=global_step.assign_add(1),\n            training_hooks=[train_hook1, train_hook2])\n      if mode == model_fn_lib.ModeKeys.EVAL:\n        eval_hook1 = tf.compat.v1.train.StepCounterHook(\n            every_n_steps=1, output_dir=self.get_temp_dir())\n        eval_hook2 = tf.compat.v1.test.mock.MagicMock(\n            
wraps=tf.compat.v1.train.SessionRunHook(),\n            spec=tf.compat.v1.train.SessionRunHook)\n        return model_fn_lib.EstimatorSpec(\n            mode=mode,\n            loss=tf.constant(1.),\n            evaluation_hooks=[eval_hook1, eval_hook2])\n    num_steps = 10\n    estimator = estimator_lib.EstimatorV2(\n        model_fn=model_fn, model_dir=self.get_temp_dir(), config=config)\n    if use_train_and_evaluate:\n      training.train_and_evaluate(\n          estimator, training.TrainSpec(input_fn, max_steps=num_steps),\n          training.EvalSpec(input_fn))\n    else:\n      estimator.train(input_fn, steps=num_steps)\n      estimator.evaluate(input_fn, steps=num_steps)\n\n  @tf.compat.v2.__internal__.distribute.combinations.generate(\n      tf.compat.v2.__internal__.test.combinations.combine(\n          mode=['graph'],\n          distribution=[\n              tf.compat.v2.__internal__.distribute.combinations.one_device_strategy,\n              tf.compat.v2.__internal__.distribute.combinations.mirrored_strategy_with_gpu_and_cpu,\n              tf.compat.v2.__internal__.distribute.combinations.mirrored_strategy_with_two_gpus\n          ],\n          use_train_and_evaluate=[True, False]))\n  def test_complete_flow_with_mode(self, distribution, use_train_and_evaluate):\n    label_dimension = 2\n    input_dimension = label_dimension\n    batch_size = 10\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n    train_input_fn = self.dataset_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size // distribution.num_replicas_in_sync,\n        shuffle=True)\n    eval_input_fn = self.dataset_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size // distribution.num_replicas_in_sync,\n        shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, batch_size=batch_size, shuffle=False)\n\n    
linear_feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    dnn_feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    feature_columns = linear_feature_columns + dnn_feature_columns\n    estimator = dnn_linear_combined.DNNLinearCombinedRegressor(\n        linear_feature_columns=linear_feature_columns,\n        dnn_hidden_units=(2, 2),\n        dnn_feature_columns=dnn_feature_columns,\n        label_dimension=label_dimension,\n        model_dir=self._model_dir,\n        # TODO(isaprykin): Work around the colocate_with error.\n        dnn_optimizer='Adagrad',\n        linear_optimizer='Adagrad',\n        config=run_config.RunConfig(\n            train_distribute=distribution, eval_distribute=distribution))\n\n    num_steps = 10\n    if use_train_and_evaluate:\n      scores, _ = training.train_and_evaluate(\n          estimator, training.TrainSpec(train_input_fn, max_steps=num_steps),\n          training.EvalSpec(eval_input_fn))\n    else:\n      estimator.train(train_input_fn, steps=num_steps)\n      scores = estimator.evaluate(eval_input_fn)\n\n    self.assertEqual(num_steps, scores[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', scores)\n\n    predictions = np.array([\n        x[prediction_keys.PredictionKeys.PREDICTIONS]\n        for x in estimator.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, label_dimension), predictions.shape)\n\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(\n        feature_columns)\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    export_dir = estimator.export_saved_model(tempfile.mkdtemp(),\n                                              serving_input_receiver_fn)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n  def tearDown(self):\n    if self._model_dir:\n      
tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._model_dir)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/distribute_strategy_estimator_training_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests that show Distribute Coordinator works with Estimator.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport copy\nimport glob\nimport json\nimport os\nimport sys\nimport tempfile\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.distribute import distribute_coordinator as dc\nfrom tensorflow.python.distribute import estimator_training as dc_training\nfrom tensorflow.python.distribute import multi_worker_test_base\nfrom tensorflow.python.distribute import multi_worker_util\nfrom tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver\nfrom tensorflow.python.distribute.distribute_config import DistributeConfig\nfrom tensorflow.python.eager import context\nfrom tensorflow_estimator.python.estimator import exporter as exporter_lib\nfrom tensorflow_estimator.python.estimator import run_config as run_config_lib\nfrom tensorflow_estimator.python.estimator import training as estimator_training\nfrom tensorflow_estimator.python.estimator.canned import dnn_linear_combined\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.export import 
export as export_lib\n\nBATCH_SIZE = 10\nLABEL_DIMENSION = 2\nDATA = np.linspace(\n    0., 2., BATCH_SIZE * LABEL_DIMENSION,\n    dtype=np.float32).reshape(BATCH_SIZE, LABEL_DIMENSION)\nEVAL_NAME = \"foo\"\nEXPORTER_NAME = \"saved_model_exporter\"\nMAX_STEPS = 10\n\nCHIEF = dc._TaskType.CHIEF\nEVALUATOR = dc._TaskType.EVALUATOR\nWORKER = dc._TaskType.WORKER\nPS = dc._TaskType.PS\n\noriginal_run_std_server = dc._run_std_server\n\n\nclass DistributeCoordinatorIntegrationTest(\n    multi_worker_test_base.IndependentWorkerTestBase, parameterized.TestCase):\n\n  @classmethod\n  def setUpClass(cls):\n    \"\"\"Create a local cluster with 2 workers.\"\"\"\n    super(DistributeCoordinatorIntegrationTest, cls).setUpClass()\n    cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(\n        num_workers=3, num_ps=2, has_eval=True)\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n    super(DistributeCoordinatorIntegrationTest, self).setUp()\n\n  def dataset_input_fn(self, x, y, batch_size, shuffle):\n\n    def input_fn():\n      dataset = tf.compat.v1.data.Dataset.from_tensor_slices((x, y))\n      if shuffle:\n        dataset = dataset.shuffle(batch_size)\n      dataset = dataset.repeat(100).batch(batch_size)\n      return dataset\n\n    return input_fn\n\n  def _get_exporter(self, name, fc):\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(fc)\n    serving_input_receiver_fn = (\n        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))\n    return exporter_lib.LatestExporter(\n        name, serving_input_receiver_fn=serving_input_receiver_fn)\n\n  def _extract_loss_and_global_step(self, event_folder):\n    \"\"\"Returns the loss and global step in last event.\"\"\"\n    event_paths = glob.glob(os.path.join(event_folder, \"events*\"))\n    self.assertNotEmpty(\n        event_paths, msg=\"Event file not found in dir %s\" % event_folder)\n\n    loss = None\n    global_step_count = None\n\n    for e in 
tf.compat.v1.train.summary_iterator(event_paths[-1]):\n      current_loss = None\n      for v in e.summary.value:\n        if v.tag == \"loss\":\n          current_loss = v.simple_value\n\n      # If loss is not found, global step is meaningless.\n      if current_loss is None:\n        continue\n\n      current_global_step = e.step\n      if global_step_count is None or current_global_step > global_step_count:\n        global_step_count = current_global_step\n        loss = current_loss\n\n    return (loss, global_step_count)\n\n  def _get_estimator(self,\n                     train_distribute,\n                     eval_distribute,\n                     remote_cluster=None):\n    input_dimension = LABEL_DIMENSION\n    linear_feature_columns = [\n        tf.compat.v1.feature_column.numeric_column(\"x\", shape=(input_dimension,))\n    ]\n    dnn_feature_columns = [\n        tf.compat.v1.feature_column.numeric_column(\"x\", shape=(input_dimension,))\n    ]\n\n    return dnn_linear_combined.DNNLinearCombinedRegressor(\n        linear_feature_columns=linear_feature_columns,\n        dnn_hidden_units=(2, 2),\n        dnn_feature_columns=dnn_feature_columns,\n        label_dimension=LABEL_DIMENSION,\n        model_dir=self._model_dir,\n        dnn_optimizer=\"Adagrad\",\n        linear_optimizer=\"Adagrad\",\n        config=run_config_lib.RunConfig(\n            experimental_distribute=DistributeConfig(\n                train_distribute=train_distribute,\n                eval_distribute=eval_distribute,\n                remote_cluster=remote_cluster)))\n\n  def _complete_flow(self,\n                     train_distribute,\n                     eval_distribute,\n                     remote_cluster=None,\n                     use_train_and_evaluate=True):\n    estimator = self._get_estimator(train_distribute, eval_distribute,\n                                    remote_cluster)\n\n    input_dimension = LABEL_DIMENSION\n    train_input_fn = self.dataset_input_fn(\n        
x={\"x\": DATA},\n        y=DATA,\n        batch_size=BATCH_SIZE // train_distribute.num_replicas_in_sync,\n        shuffle=True)\n    if eval_distribute:\n      eval_batch_size = BATCH_SIZE // eval_distribute.num_replicas_in_sync\n    else:\n      eval_batch_size = BATCH_SIZE\n    eval_input_fn = self.dataset_input_fn(\n        x={\"x\": DATA}, y=DATA, batch_size=eval_batch_size, shuffle=False)\n\n    linear_feature_columns = [\n        tf.compat.v1.feature_column.numeric_column(\"x\", shape=(input_dimension,))\n    ]\n    dnn_feature_columns = [\n        tf.compat.v1.feature_column.numeric_column(\"x\", shape=(input_dimension,))\n    ]\n    feature_columns = linear_feature_columns + dnn_feature_columns\n\n    eval_spec = estimator_training.EvalSpec(\n        name=EVAL_NAME,\n        input_fn=eval_input_fn,\n        steps=None,\n        exporters=self._get_exporter(EXPORTER_NAME, feature_columns),\n        start_delay_secs=0,\n        throttle_secs=1)\n\n    if use_train_and_evaluate:\n      estimator_training.train_and_evaluate(\n          estimator,\n          estimator_training.TrainSpec(train_input_fn, max_steps=MAX_STEPS),\n          eval_spec)\n    else:\n      estimator.train(train_input_fn, max_steps=MAX_STEPS)\n\n      latest_ckpt_path = estimator.latest_checkpoint()\n      metrics = estimator.evaluate(\n          eval_input_fn, checkpoint_path=latest_ckpt_path, name=EVAL_NAME)\n\n      # Export the eval result to files.\n      eval_result = estimator_training._EvalResult(\n          status=estimator_training._EvalStatus.EVALUATED,\n          metrics=metrics,\n          checkpoint_path=latest_ckpt_path)\n      evaluator = estimator_training._TrainingExecutor._Evaluator(\n          estimator, eval_spec, None)\n      evaluator._export_eval_result(eval_result, True)\n\n    return estimator\n\n  def _inspect_train_and_eval_events(self, estimator):\n    # Make sure nothing is stuck in limbo.\n    tf.compat.v1.summary.FileWriterCache.clear()\n\n    # Examine 
the training events. Use a range to check global step to avoid\n    # flakyness due to global step race condition.\n    training_loss, _ = self._extract_loss_and_global_step(self._model_dir)\n    self.assertIsNotNone(training_loss)\n\n    # Examine the eval events. The global step should be accurate.\n    eval_dir = os.path.join(self._model_dir, \"eval_\" + EVAL_NAME)\n    eval_loss, eval_global_step = self._extract_loss_and_global_step(\n        event_folder=eval_dir)\n    self.assertIsNotNone(eval_loss)\n    self.assertGreaterEqual(eval_global_step, MAX_STEPS)\n\n    # Examine the export folder.\n    export_dir = os.path.join(\n        os.path.join(self._model_dir, \"export\"), EXPORTER_NAME)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n    # Examine the ckpt for predict.\n    def predict_input_fn():\n      return tf.compat.v1.data.Dataset.from_tensor_slices({\"x\": DATA}).batch(BATCH_SIZE)\n\n    predicted_proba = np.array([\n        x[prediction_keys.PredictionKeys.PREDICTIONS]\n        for x in estimator.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((BATCH_SIZE, LABEL_DIMENSION), predicted_proba.shape)\n\n  def _make_cross_device_ops(self, num_gpus_per_worker):\n    return tf.distribute.ReductionToOneDevice()\n\n  def _get_strategy_object(self,\n                           strategy_cls,\n                           cluster_spec=None,\n                           eval_strategy=False):\n    if strategy_cls == tf.compat.v1.distribute.MirroredStrategy:\n      if eval_strategy:\n        return strategy_cls()\n      else:\n        return strategy_cls(\n            cross_device_ops=self._make_cross_device_ops(\n                num_gpus_per_worker=context.num_gpus()))\n    elif (strategy_cls == tf.compat.v1.distribute.MirroredStrategy and not eval_strategy):\n      return strategy_cls(\n          num_gpus_per_worker=context.num_gpus(),\n          cross_device_ops=self._make_cross_device_ops(\n              
num_gpus_per_worker=context.num_gpus()))\n    elif strategy_cls == tf.compat.v1.distribute.experimental.ParameterServerStrategy:\n      assert cluster_spec is not None\n      cluster_resolver = SimpleClusterResolver(\n          cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),\n          task_type=\"ps\",\n          task_id=0,\n          num_accelerators={\"GPU\": context.num_gpus()})\n      return strategy_cls(cluster_resolver)\n    elif strategy_cls == tf.compat.v1.distribute.experimental.CentralStorageStrategy:\n      return strategy_cls._from_num_gpus(context.num_gpus())\n    else:\n      return strategy_cls()\n\n  @tf.compat.v2.__internal__.distribute.combinations.generate(\n      tf.compat.v2.__internal__.test.combinations.combine(\n          mode=[\"graph\"],\n          train_distribute_cls=[\n              tf.compat.v1.distribute.experimental.MultiWorkerMirroredStrategy,\n              tf.compat.v1.distribute.MirroredStrategy,\n              tf.compat.v1.distribute.experimental.ParameterServerStrategy\n          ],\n          eval_distribute_cls=[\n              None,\n              tf.compat.v1.distribute.MirroredStrategy,\n              tf.compat.v2.distribute.experimental.CentralStorageStrategy,\n              tf.compat.v1.distribute.experimental.MultiWorkerMirroredStrategy,\n          ],\n          required_gpus=[0, 1]))\n  def test_complete_flow_standalone_client(self, train_distribute_cls,\n                                           eval_distribute_cls):\n\n    cluster_spec = copy.deepcopy(self._cluster_spec)\n    if (train_distribute_cls !=\n        tf.compat.v1.distribute.experimental.ParameterServerStrategy):\n      cluster_spec.pop(\"ps\", None)\n\n    train_distribute = self._get_strategy_object(\n        train_distribute_cls, cluster_spec=cluster_spec)\n\n    if eval_distribute_cls:\n      eval_distribute = self._get_strategy_object(\n          eval_distribute_cls, eval_strategy=True)\n    else:\n      eval_distribute = 
None\n\n    estimator = self._complete_flow(train_distribute, eval_distribute,\n                                    cluster_spec)\n    self._inspect_train_and_eval_events(estimator)\n\n  @tf.compat.v2.__internal__.distribute.combinations.generate(\n      tf.compat.v2.__internal__.test.combinations.combine(\n          mode=[\"graph\"],\n          eval_distribute_class=[\n              None,\n              tf.compat.v1.distribute.MirroredStrategy,\n              tf.compat.v2.distribute.experimental.CentralStorageStrategy,\n          ],\n          required_gpus=[0, 1]))\n  def test_complete_flow_standalone_client_collective_nccl(\n      self, eval_distribute_class):\n    train_distribute = (\n        tf.compat.v1.distribute.experimental.MultiWorkerMirroredStrategy(\n            communication=tf.compat.v1.distribute.experimental.CollectiveCommunication\n            .NCCL))\n\n    if eval_distribute_class:\n      eval_distribute = self._get_strategy_object(\n          eval_distribute_class, eval_strategy=True)\n    else:\n      eval_distribute = None\n\n    cluster_spec = copy.deepcopy(self._cluster_spec)\n    cluster_spec.pop(\"ps\", None)\n    estimator = self._complete_flow(train_distribute, eval_distribute,\n                                    cluster_spec)\n    self._inspect_train_and_eval_events(estimator)\n\n  @tf.compat.v2.__internal__.distribute.combinations.generate(\n      tf.compat.v2.__internal__.test.combinations.combine(\n          mode=[\"graph\"],\n          train_distribute_cls=[\n              tf.compat.v1.distribute.MirroredStrategy,\n          ],\n          eval_distribute_cls=[\n              None,\n              tf.compat.v1.distribute.MirroredStrategy,\n          ],\n          required_gpus=[0, 1]))\n  def test_estimator_standalone_client(self, train_distribute_cls,\n                                       eval_distribute_cls):\n    train_distribute = self._get_strategy_object(train_distribute_cls)\n\n    if eval_distribute_cls:\n      
eval_distribute = self._get_strategy_object(eval_distribute_cls)\n    else:\n      eval_distribute = None\n\n    # We use the whole cluster for evaluation.\n    cluster = copy.deepcopy(self._cluster_spec)\n    cluster.pop(\"evaluator\", None)\n\n    estimator = self._complete_flow(\n        train_distribute,\n        eval_distribute,\n        remote_cluster=cluster,\n        use_train_and_evaluate=False)\n    self._inspect_train_and_eval_events(estimator)\n\n  def _mock_run_std_server(self, *args, **kwargs):\n    ret = original_run_std_server(*args, **kwargs)\n    # Wait for all std servers to be brought up in order to reduce the chance of\n    # remote sessions taking local ports that have been assigned to std servers.\n    self._barrier.wait()\n    return ret\n\n  def _independent_worker_fn(\n      self,\n      train_distribute,\n      eval_distribute,\n  ):\n    train_distribute = copy.deepcopy(train_distribute)\n    eval_distribute = copy.deepcopy(eval_distribute)\n    with tf.compat.v1.test.mock.patch.object(dc, \"_run_std_server\",\n                                   self._mock_run_std_server):\n      self._complete_flow(train_distribute, eval_distribute)\n\n  @tf.compat.v2.__internal__.distribute.combinations.generate(\n      tf.compat.v2.__internal__.test.combinations.combine(\n          mode=[\"graph\"],\n          train_distribute_cls=[\n              tf.compat.v1.distribute.experimental.MultiWorkerMirroredStrategy,\n              tf.compat.v1.distribute.experimental.ParameterServerStrategy,\n          ],\n          eval_distribute_cls=[\n              None,\n              tf.compat.v1.distribute.MirroredStrategy,\n              tf.compat.v2.distribute.experimental.CentralStorageStrategy,\n              tf.compat.v1.distribute.experimental.MultiWorkerMirroredStrategy,\n          ],\n          required_gpus=[0, 1]))\n  def test_complete_flow_independent_worker_between_graph(\n      self, train_distribute_cls, eval_distribute_cls):\n    if 
(context.num_gpus() < 2 and eval_distribute_cls ==\n        tf.compat.v1.distribute.experimental.MultiWorkerMirroredStrategy):\n      self.skipTest(\"`CollectiveAllReduceStrategy` needs at least two towers.\")\n\n    if (train_distribute_cls ==\n        tf.compat.v1.distribute.experimental.ParameterServerStrategy):\n      cluster_spec = tf.compat.v2.__internal__.distribute.multi_process_runner.create_cluster_spec(\n          num_workers=3, num_ps=2, has_eval=True)\n      # 3 workers, 2 ps.\n      self._barrier = dc._Barrier(5)\n    else:\n      cluster_spec = tf.compat.v2.__internal__.distribute.multi_process_runner.create_cluster_spec(\n          num_workers=3, num_ps=0, has_eval=True)\n      # 3 workers.\n      self._barrier = dc._Barrier(3)\n\n    train_distribute = self._get_strategy_object(\n        train_distribute_cls, cluster_spec=cluster_spec)\n\n    if eval_distribute_cls:\n      eval_distribute = self._get_strategy_object(\n          eval_distribute_cls, eval_strategy=True)\n    else:\n      eval_distribute = None\n\n    threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn,\n                                                 cluster_spec, train_distribute,\n                                                 eval_distribute)\n    threads_to_join = []\n    for task_type, ts in threads.items():\n      if task_type == PS:\n        continue\n      for t in ts:\n        threads_to_join.append(t)\n    self.join_independent_workers(threads_to_join)\n\n    estimator = self._get_estimator(train_distribute, eval_distribute)\n    self._inspect_train_and_eval_events(estimator)\n\n  @tf.compat.v2.__internal__.distribute.combinations.generate(\n      tf.compat.v2.__internal__.test.combinations.combine(\n          mode=[\"graph\"],\n          train_distribute_cls=[\n              tf.compat.v1.distribute.MirroredStrategy,\n          ],\n          eval_distribute_cls=[\n              None,\n              tf.compat.v1.distribute.MirroredStrategy,\n          
],\n          required_gpus=[0, 1]))\n  def test_complete_flow_independent_worker_in_graph(self, train_distribute_cls,\n                                                     eval_distribute_cls):\n    train_distribute = self._get_strategy_object(train_distribute_cls)\n\n    if eval_distribute_cls:\n      eval_distribute = self._get_strategy_object(\n          eval_distribute_cls, eval_strategy=True)\n    else:\n      eval_distribute = None\n\n    cluster_spec = tf.compat.v2.__internal__.distribute.multi_process_runner.create_cluster_spec(\n        num_workers=3, num_ps=0, has_eval=True)\n    # 3 workers.\n    self._barrier = dc._Barrier(3)\n    threads = self.run_multiple_tasks_in_threads(self._independent_worker_fn,\n                                                 cluster_spec, train_distribute,\n                                                 eval_distribute)\n    self.join_independent_workers([threads[WORKER][0], threads[EVALUATOR][0]])\n\n    estimator = self._get_estimator(train_distribute, eval_distribute)\n    self._inspect_train_and_eval_events(estimator)\n\n\nTF_CONFIG_WITH_CHIEF = {\n    \"cluster\": {\n        \"chief\": [\"fake_chief\"],\n    },\n    \"task\": {\n        \"type\": \"chief\",\n        \"index\": 0\n    }\n}\n\nTF_CONFIG_WITH_MASTER = {\n    \"cluster\": {\n        \"master\": [\"fake_master\"],\n    },\n    \"task\": {\n        \"type\": \"master\",\n        \"index\": 0\n    }\n}\n\nTF_CONFIG_WITHOUT_TASK = {\"cluster\": {\"chief\": [\"fake_worker\"]}}\n\n\nclass RunConfigTest(tf.compat.v1.test.TestCase):\n\n  def test_previously_unexpected_cluster_spec(self):\n    with tf.compat.v1.test.mock.patch.dict(\n        \"os.environ\", {\"TF_CONFIG\": json.dumps(TF_CONFIG_WITHOUT_TASK)}):\n      run_config_lib.RunConfig(\n          experimental_distribute=DistributeConfig(\n              train_distribute=tf.compat.v1.distribute.MirroredStrategy(\n                  [\"/device:GPU:0\", \"/device:GPU:1\"])))\n\n  def 
test_should_run_distribute_coordinator(self):\n    \"\"\"Tests that should_run_distribute_coordinator return a correct value.\"\"\"\n    # We don't use distribute coordinator for local training.\n    self.assertFalse(\n        dc_training.should_run_distribute_coordinator(\n            run_config_lib.RunConfig()))\n\n    # When `train_distribute` is not specified, don't use distribute\n    # coordinator.\n    with tf.compat.v1.test.mock.patch.dict(\n        \"os.environ\", {\"TF_CONFIG\": json.dumps(TF_CONFIG_WITH_CHIEF)}):\n      self.assertFalse(\n          dc_training.should_run_distribute_coordinator(\n              run_config_lib.RunConfig()))\n\n    # When `train_distribute` is specified and TF_CONFIG is detected, use\n    # distribute coordinator.\n    with tf.compat.v1.test.mock.patch.dict(\n        \"os.environ\", {\"TF_CONFIG\": json.dumps(TF_CONFIG_WITH_CHIEF)}):\n      config_with_train_distribute = run_config_lib.RunConfig(\n          experimental_distribute=DistributeConfig(\n              train_distribute=tf.compat.v1.distribute.MirroredStrategy(\n                  [\"/device:GPU:0\", \"/device:GPU:1\"])))\n      config_with_eval_distribute = run_config_lib.RunConfig(\n          experimental_distribute=DistributeConfig(\n              eval_distribute=tf.compat.v1.distribute.MirroredStrategy(\n                  [\"/device:GPU:0\", \"/device:GPU:1\"])))\n    self.assertTrue(\n        dc_training.should_run_distribute_coordinator(\n            config_with_train_distribute))\n    self.assertFalse(\n        dc_training.should_run_distribute_coordinator(\n            config_with_eval_distribute))\n\n    # With a master in the cluster, don't run distribute coordinator.\n    with tf.compat.v1.test.mock.patch.dict(\n        \"os.environ\", {\"TF_CONFIG\": json.dumps(TF_CONFIG_WITH_MASTER)}):\n      config = run_config_lib.RunConfig(\n          experimental_distribute=DistributeConfig(\n              train_distribute=tf.compat.v1.distribute.MirroredStrategy(\n 
                 [\"/device:GPU:0\", \"/device:GPU:1\"])))\n    self.assertFalse(dc_training.should_run_distribute_coordinator(config))\n\n  def test_init_run_config_duplicate_distribute(self):\n    with self.assertRaises(ValueError):\n      run_config_lib.RunConfig(\n          train_distribute=tf.compat.v1.distribute.MirroredStrategy(),\n          experimental_distribute=DistributeConfig(\n              train_distribute=tf.compat.v1.distribute.MirroredStrategy()))\n\n    with self.assertRaises(ValueError):\n      run_config_lib.RunConfig(\n          eval_distribute=tf.compat.v1.distribute.MirroredStrategy(),\n          experimental_distribute=DistributeConfig(\n              eval_distribute=tf.compat.v1.distribute.MirroredStrategy()))\n\n  def test_init_run_config_none_distribute_coordinator_mode(self):\n    # We don't use distribute coordinator for local training.\n    config = run_config_lib.RunConfig(\n        train_distribute=tf.compat.v1.distribute.MirroredStrategy())\n    dc_training.init_run_config(config, {})\n    self.assertIsNone(config._distribute_coordinator_mode)\n\n    # With a master in the cluster, don't run distribute coordinator.\n    with tf.compat.v1.test.mock.patch.dict(\n        \"os.environ\", {\"TF_CONFIG\": json.dumps(TF_CONFIG_WITH_MASTER)}):\n      config = run_config_lib.RunConfig(\n          train_distribute=tf.compat.v1.distribute.MirroredStrategy())\n      self.assertIsNone(config._distribute_coordinator_mode)\n\n    # When `train_distribute` is not specified, don't use distribute\n    # coordinator.\n    with tf.compat.v1.test.mock.patch.dict(\n        \"os.environ\", {\"TF_CONFIG\": json.dumps(TF_CONFIG_WITH_CHIEF)}):\n      config = run_config_lib.RunConfig()\n      self.assertFalse(hasattr(config, \"_distribute_coordinator_mode\"))\n\n  def test_init_run_config_independent_worker(self):\n    # When `train_distribute` is specified and TF_CONFIG is detected, use\n    # distribute coordinator with INDEPENDENT_WORKER mode.\n    with 
tf.compat.v1.test.mock.patch.dict(\n        \"os.environ\", {\"TF_CONFIG\": json.dumps(TF_CONFIG_WITH_CHIEF)}):\n      config = run_config_lib.RunConfig(\n          train_distribute=tf.compat.v1.distribute.MirroredStrategy())\n    self.assertEqual(config._distribute_coordinator_mode,\n                     dc.CoordinatorMode.INDEPENDENT_WORKER)\n\n  def test_init_run_config_standalone_client(self):\n    # When `train_distribute` is specified, TF_CONFIG is detected and\n    # `experimental.remote_cluster` is set use distribute coordinator with\n    # STANDALONE_CLIENT mode.\n    config = run_config_lib.RunConfig(\n        train_distribute=tf.compat.v1.distribute.MirroredStrategy(),\n        experimental_distribute=DistributeConfig(\n            remote_cluster={\"chief\": [\"fake_worker\"]}))\n    self.assertEqual(config._distribute_coordinator_mode,\n                     dc.CoordinatorMode.STANDALONE_CLIENT)\n\n\nif __name__ == \"__main__\":\n  # Reduce `recovery_wait_secs` from 30 seconds so the test completes quickly.\n  orig_init = tf.compat.v1.train.SessionManager.__init__\n\n  def new_init(*args, **kwargs):\n    kwargs.pop(\"recovery_wait_secs\", None)\n    kwargs[\"recovery_wait_secs\"] = 0.5\n    orig_init(*args, **kwargs)\n\n  tf.compat.v1.train.SessionManager.__init__ = new_init\n\n  with tf.compat.v1.test.mock.patch.object(sys, \"exit\", os._exit):\n    tf.compat.v1.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/early_stopping.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Utilities for early stopping.\"\"\"\n\nimport collections\nimport operator\nimport os\n\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator import estimator as estimator_lib\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v1\n\n\n_EVENT_FILE_GLOB_PATTERN = 'events.out.tfevents.*'\n\n\n@estimator_export('estimator.experimental.make_early_stopping_hook')\ndef make_early_stopping_hook(estimator,\n                             should_stop_fn,\n                             run_every_secs=60,\n                             run_every_steps=None):\n  \"\"\"Creates early-stopping hook.\n\n  Returns a `SessionRunHook` that stops training when `should_stop_fn` returns\n  `True`.\n\n  Usage example:\n\n  ```python\n  estimator = ...\n  hook = early_stopping.make_early_stopping_hook(\n      estimator, should_stop_fn=make_stop_fn(...))\n  train_spec = tf.estimator.TrainSpec(..., hooks=[hook])\n  tf.estimator.train_and_evaluate(estimator, train_spec, ...)\n  ```\n\n  Caveat: Current implementation supports early-stopping both training and\n  evaluation in local mode. 
In distributed mode, training can be stopped but\n  evaluation (where it's a separate job) will indefinitely wait for new model\n  checkpoints to evaluate, so you will need other means to detect and stop it.\n  Early-stopping evaluation in distributed mode requires changes in\n  `train_and_evaluate` API and will be addressed in a future revision.\n\n  Args:\n    estimator: A `tf.estimator.Estimator` instance.\n    should_stop_fn: `callable`, function that takes no arguments and returns a\n      `bool`. If the function returns `True`, stopping will be initiated by the\n      chief.\n    run_every_secs: If specified, calls `should_stop_fn` at an interval of\n      `run_every_secs` seconds. Defaults to 60 seconds. Either this or\n      `run_every_steps` must be set.\n    run_every_steps: If specified, calls `should_stop_fn` every\n      `run_every_steps` steps. Either this or `run_every_secs` must be set.\n\n  Returns:\n    A `SessionRunHook` that periodically executes `should_stop_fn` and initiates\n    early stopping if the function returns `True`.\n\n  Raises:\n    TypeError: If `estimator` is not of type `tf.estimator.Estimator`.\n    ValueError: If both `run_every_secs` and `run_every_steps` are set.\n  \"\"\"\n  if not isinstance(estimator, estimator_lib.Estimator):\n    raise TypeError('`estimator` must have type `tf.estimator.Estimator`. 
'\n                    'Got: {}'.format(type(estimator)))\n\n  if run_every_secs is not None and run_every_steps is not None:\n    raise ValueError('Only one of `run_every_secs` and `run_every_steps` must '\n                     'be set.')\n\n  train_distribute = estimator.config.train_distribute\n  mwms = ['CollectiveAllReduceStrategy', 'MultiWorkerMirroredStrategy']\n  if train_distribute and any(\n      train_distribute.__class__.__name__.startswith(strategy)\n      for strategy in mwms):\n    if run_every_secs:\n      raise ValueError('run_every_secs should not be set when using '\n                       'MultiWorkerMirroredStrategy.')\n    return _MultiWorkerEarlyStoppingHook(should_stop_fn, run_every_steps)\n\n  if estimator.config.is_chief:\n    return _StopOnPredicateHook(should_stop_fn, run_every_secs, run_every_steps)\n  else:\n    return _CheckForStoppingHook()\n\n\n@estimator_export('estimator.experimental.stop_if_higher_hook')\ndef stop_if_higher_hook(estimator,\n                        metric_name,\n                        threshold,\n                        eval_dir=None,\n                        min_steps=0,\n                        run_every_secs=60,\n                        run_every_steps=None):\n  \"\"\"Creates hook to stop if the given metric is higher than the threshold.\n\n  Usage example:\n\n  ```python\n  estimator = ...\n  # Hook to stop training if accuracy becomes higher than 0.9.\n  hook = early_stopping.stop_if_higher_hook(estimator, \"accuracy\", 0.9)\n  train_spec = tf.estimator.TrainSpec(..., hooks=[hook])\n  tf.estimator.train_and_evaluate(estimator, train_spec, ...)\n  ```\n\n  Caveat: Current implementation supports early-stopping both training and\n  evaluation in local mode. 
In distributed mode, training can be stopped but\n  evaluation (where it's a separate job) will indefinitely wait for new model\n  checkpoints to evaluate, so you will need other means to detect and stop it.\n  Early-stopping evaluation in distributed mode requires changes in\n  `train_and_evaluate` API and will be addressed in a future revision.\n\n  Args:\n    estimator: A `tf.estimator.Estimator` instance.\n    metric_name: `str`, metric to track. \"loss\", \"accuracy\", etc.\n    threshold: Numeric threshold for the given metric.\n    eval_dir: If set, directory containing summary files with eval metrics. By\n      default, `estimator.eval_dir()` will be used.\n    min_steps: `int`, stop is never requested if global step is less than this\n      value. Defaults to 0.\n    run_every_secs: If specified, calls `should_stop_fn` at an interval of\n      `run_every_secs` seconds. Defaults to 60 seconds. Either this or\n      `run_every_steps` must be set.\n    run_every_steps: If specified, calls `should_stop_fn` every\n      `run_every_steps` steps. 
Either this or `run_every_secs` must be set.\n\n  Returns:\n    An early-stopping hook of type `SessionRunHook` that periodically checks\n    if the given metric is higher than specified threshold and initiates\n    early stopping if true.\n  \"\"\"\n  return _stop_if_threshold_crossed_hook(\n      estimator=estimator,\n      metric_name=metric_name,\n      threshold=threshold,\n      higher_is_better=True,\n      eval_dir=eval_dir,\n      min_steps=min_steps,\n      run_every_secs=run_every_secs,\n      run_every_steps=run_every_steps)\n\n\n@estimator_export('estimator.experimental.stop_if_lower_hook')\ndef stop_if_lower_hook(estimator,\n                       metric_name,\n                       threshold,\n                       eval_dir=None,\n                       min_steps=0,\n                       run_every_secs=60,\n                       run_every_steps=None):\n  \"\"\"Creates hook to stop if the given metric is lower than the threshold.\n\n  Usage example:\n\n  ```python\n  estimator = ...\n  # Hook to stop training if loss becomes lower than 100.\n  hook = early_stopping.stop_if_lower_hook(estimator, \"loss\", 100)\n  train_spec = tf.estimator.TrainSpec(..., hooks=[hook])\n  tf.estimator.train_and_evaluate(estimator, train_spec, ...)\n  ```\n\n  Caveat: Current implementation supports early-stopping both training and\n  evaluation in local mode. In distributed mode, training can be stopped but\n  evaluation (where it's a separate job) will indefinitely wait for new model\n  checkpoints to evaluate, so you will need other means to detect and stop it.\n  Early-stopping evaluation in distributed mode requires changes in\n  `train_and_evaluate` API and will be addressed in a future revision.\n\n  Args:\n    estimator: A `tf.estimator.Estimator` instance.\n    metric_name: `str`, metric to track. \"loss\", \"accuracy\", etc.\n    threshold: Numeric threshold for the given metric.\n    eval_dir: If set, directory containing summary files with eval metrics. 
By\n      default, `estimator.eval_dir()` will be used.\n    min_steps: `int`, stop is never requested if global step is less than this\n      value. Defaults to 0.\n    run_every_secs: If specified, calls `should_stop_fn` at an interval of\n      `run_every_secs` seconds. Defaults to 60 seconds. Either this or\n      `run_every_steps` must be set.\n    run_every_steps: If specified, calls `should_stop_fn` every\n      `run_every_steps` steps. Either this or `run_every_secs` must be set.\n\n  Returns:\n    An early-stopping hook of type `SessionRunHook` that periodically checks\n    if the given metric is lower than specified threshold and initiates\n    early stopping if true.\n  \"\"\"\n  return _stop_if_threshold_crossed_hook(\n      estimator=estimator,\n      metric_name=metric_name,\n      threshold=threshold,\n      higher_is_better=False,\n      eval_dir=eval_dir,\n      min_steps=min_steps,\n      run_every_secs=run_every_secs,\n      run_every_steps=run_every_steps)\n\n\n@estimator_export('estimator.experimental.stop_if_no_increase_hook')\ndef stop_if_no_increase_hook(estimator,\n                             metric_name,\n                             max_steps_without_increase,\n                             eval_dir=None,\n                             min_steps=0,\n                             run_every_secs=60,\n                             run_every_steps=None):\n  \"\"\"Creates hook to stop if metric does not increase within given max steps.\n\n  Usage example:\n\n  ```python\n  estimator = ...\n  # Hook to stop training if accuracy does not increase in over 100000 steps.\n  hook = early_stopping.stop_if_no_increase_hook(estimator, \"accuracy\", 100000)\n  train_spec = tf.estimator.TrainSpec(..., hooks=[hook])\n  tf.estimator.train_and_evaluate(estimator, train_spec, ...)\n  ```\n\n  Caveat: Current implementation supports early-stopping both training and\n  evaluation in local mode. 
In distributed mode, training can be stopped but\n  evaluation (where it's a separate job) will indefinitely wait for new model\n  checkpoints to evaluate, so you will need other means to detect and stop it.\n  Early-stopping evaluation in distributed mode requires changes in\n  `train_and_evaluate` API and will be addressed in a future revision.\n\n  Args:\n    estimator: A `tf.estimator.Estimator` instance.\n    metric_name: `str`, metric to track. \"loss\", \"accuracy\", etc.\n    max_steps_without_increase: `int`, maximum number of training steps with no\n      increase in the given metric.\n    eval_dir: If set, directory containing summary files with eval metrics. By\n      default, `estimator.eval_dir()` will be used.\n    min_steps: `int`, stop is never requested if global step is less than this\n      value. Defaults to 0.\n    run_every_secs: If specified, calls `should_stop_fn` at an interval of\n      `run_every_secs` seconds. Defaults to 60 seconds. Either this or\n      `run_every_steps` must be set.\n    run_every_steps: If specified, calls `should_stop_fn` every\n      `run_every_steps` steps. 
Either this or `run_every_secs` must be set.\n\n  Returns:\n    An early-stopping hook of type `SessionRunHook` that periodically checks\n    if the given metric shows no increase over given maximum number of\n    training steps, and initiates early stopping if true.\n  \"\"\"\n  return _stop_if_no_metric_improvement_hook(\n      estimator=estimator,\n      metric_name=metric_name,\n      max_steps_without_improvement=max_steps_without_increase,\n      higher_is_better=True,\n      eval_dir=eval_dir,\n      min_steps=min_steps,\n      run_every_secs=run_every_secs,\n      run_every_steps=run_every_steps)\n\n\n@estimator_export('estimator.experimental.stop_if_no_decrease_hook')\ndef stop_if_no_decrease_hook(estimator,\n                             metric_name,\n                             max_steps_without_decrease,\n                             eval_dir=None,\n                             min_steps=0,\n                             run_every_secs=60,\n                             run_every_steps=None):\n  \"\"\"Creates hook to stop if metric does not decrease within given max steps.\n\n  Usage example:\n\n  ```python\n  estimator = ...\n  # Hook to stop training if loss does not decrease in over 100000 steps.\n  hook = early_stopping.stop_if_no_decrease_hook(estimator, \"loss\", 100000)\n  train_spec = tf.estimator.TrainSpec(..., hooks=[hook])\n  tf.estimator.train_and_evaluate(estimator, train_spec, ...)\n  ```\n\n  Caveat: Current implementation supports early-stopping both training and\n  evaluation in local mode. In distributed mode, training can be stopped but\n  evaluation (where it's a separate job) will indefinitely wait for new model\n  checkpoints to evaluate, so you will need other means to detect and stop it.\n  Early-stopping evaluation in distributed mode requires changes in\n  `train_and_evaluate` API and will be addressed in a future revision.\n\n  Args:\n    estimator: A `tf.estimator.Estimator` instance.\n    metric_name: `str`, metric to track. 
\"loss\", \"accuracy\", etc.\n    max_steps_without_decrease: `int`, maximum number of training steps with no\n      decrease in the given metric.\n    eval_dir: If set, directory containing summary files with eval metrics. By\n      default, `estimator.eval_dir()` will be used.\n    min_steps: `int`, stop is never requested if global step is less than this\n      value. Defaults to 0.\n    run_every_secs: If specified, calls `should_stop_fn` at an interval of\n      `run_every_secs` seconds. Defaults to 60 seconds. Either this or\n      `run_every_steps` must be set.\n    run_every_steps: If specified, calls `should_stop_fn` every\n      `run_every_steps` steps. Either this or `run_every_secs` must be set.\n\n  Returns:\n    An early-stopping hook of type `SessionRunHook` that periodically checks\n    if the given metric shows no decrease over given maximum number of\n    training steps, and initiates early stopping if true.\n  \"\"\"\n  return _stop_if_no_metric_improvement_hook(\n      estimator=estimator,\n      metric_name=metric_name,\n      max_steps_without_improvement=max_steps_without_decrease,\n      higher_is_better=False,\n      eval_dir=eval_dir,\n      min_steps=min_steps,\n      run_every_secs=run_every_secs,\n      run_every_steps=run_every_steps)\n\n\ndef read_eval_metrics(eval_dir):\n  \"\"\"Helper to read eval metrics from eval summary files.\n\n  Args:\n    eval_dir: Directory containing summary files with eval metrics.\n\n  Returns:\n    A `dict` with global steps mapping to `dict` of metric names and values.\n  \"\"\"\n  eval_metrics_dict = collections.defaultdict(dict)\n  for event in _summaries(eval_dir):\n    if not event.HasField('summary'):\n      continue\n    metrics = {}\n    for value in event.summary.value:\n      if value.HasField('simple_value'):\n        metrics[value.tag] = value.simple_value\n    if metrics:\n      eval_metrics_dict[event.step].update(metrics)\n  return collections.OrderedDict(\n      
sorted(eval_metrics_dict.items(), key=lambda t: t[0]))\n\n\ndef _stop_if_threshold_crossed_hook(estimator, metric_name, threshold,\n                                    higher_is_better, eval_dir, min_steps,\n                                    run_every_secs, run_every_steps):\n  \"\"\"Creates early-stopping hook to stop training if threshold is crossed.\"\"\"\n\n  if eval_dir is None:\n    eval_dir = estimator.eval_dir()\n\n  is_lhs_better = operator.gt if higher_is_better else operator.lt\n  greater_or_lesser = 'greater than' if higher_is_better else 'less than'\n\n  def stop_if_threshold_crossed_fn():\n    \"\"\"Returns `True` if the given metric crosses specified threshold.\"\"\"\n\n    eval_results = read_eval_metrics(eval_dir)\n\n    for step, metrics in eval_results.items():\n      if step < min_steps:\n        continue\n      val = metrics[metric_name]\n      if is_lhs_better(val, threshold):\n        tf.compat.v1.logging.info(\n            'At step %s, metric \"%s\" has value %s which is %s the configured '\n            'threshold (%s) for early stopping.', step, metric_name, val,\n            greater_or_lesser, threshold)\n        return True\n    return False\n\n  return make_early_stopping_hook(\n      estimator=estimator,\n      should_stop_fn=stop_if_threshold_crossed_fn,\n      run_every_secs=run_every_secs,\n      run_every_steps=run_every_steps)\n\n\ndef _stop_if_no_metric_improvement_hook(estimator, metric_name,\n                                        max_steps_without_improvement,\n                                        higher_is_better, eval_dir, min_steps,\n                                        run_every_secs, run_every_steps):\n  \"\"\"Returns hook to stop training if given metric shows no improvement.\"\"\"\n\n  if eval_dir is None:\n    eval_dir = estimator.eval_dir()\n\n  is_lhs_better = operator.gt if higher_is_better else operator.lt\n  increase_or_decrease = 'increase' if higher_is_better else 'decrease'\n\n  def 
stop_if_no_metric_improvement_fn():\n    \"\"\"Returns `True` if metric does not improve within max steps.\"\"\"\n\n    eval_results = read_eval_metrics(eval_dir)\n\n    best_val = None\n    best_val_step = None\n    for step, metrics in eval_results.items():\n      if step < min_steps:\n        continue\n      val = metrics[metric_name]\n      if best_val is None or is_lhs_better(val, best_val):\n        best_val = val\n        best_val_step = step\n      if step - best_val_step >= max_steps_without_improvement:\n        tf.compat.v1.logging.info(\n            'No %s in metric \"%s\" for %s steps, which is greater than or equal '\n            'to max steps (%s) configured for early stopping.',\n            increase_or_decrease, metric_name, step - best_val_step,\n            max_steps_without_improvement)\n        return True\n    return False\n\n  return make_early_stopping_hook(\n      estimator=estimator,\n      should_stop_fn=stop_if_no_metric_improvement_fn,\n      run_every_secs=run_every_secs,\n      run_every_steps=run_every_steps)\n\n\ndef _summaries(eval_dir):\n  \"\"\"Yields `tensorflow.Event` protos from event files in the eval dir.\n\n  Args:\n    eval_dir: Directory containing summary files with eval metrics.\n\n  Yields:\n    `tensorflow.Event` object read from the event files.\n  \"\"\"\n  if tf.compat.v1.gfile.Exists(eval_dir):\n    for event_file in tf.compat.v1.gfile.Glob(\n        os.path.join(eval_dir, _EVENT_FILE_GLOB_PATTERN)):\n      try:\n        for event in tf.compat.v1.train.summary_iterator(event_file):\n          yield event\n      except tf.errors.DataLossError as e:\n        # Upon DataLossError, we ignore the rest of the file and go to the next\n        # one.\n        tf.compat.v1.logging.warning(\n            'Skipping rest of the file due to encountering data corruption '\n            'error; file path: %s; original error raised by '\n            '`tf.train.summary_iterator`: %s', event_file, e)\n\n\ndef 
_get_or_create_stop_var():\n  with tf.compat.v1.variable_scope(\n      name_or_scope='signal_early_stopping',\n      values=[],\n      reuse=tf.compat.v1.AUTO_REUSE):\n    return tf.compat.v1.get_variable(\n        name='STOP',\n        shape=[],\n        dtype=tf.dtypes.bool,\n        initializer=tf.compat.v1.initializers.constant(False),\n        collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES],\n        trainable=False)\n\n\nclass _StopOnPredicateHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Hook that requests stop when `should_stop_fn` returns `True`.\"\"\"\n\n  def __init__(self, should_stop_fn, run_every_secs=60, run_every_steps=None):\n    if not callable(should_stop_fn):\n      raise TypeError('`should_stop_fn` must be callable.')\n\n    self._should_stop_fn = should_stop_fn\n    self._timer = tf.compat.v1.train.SecondOrStepTimer(\n        every_secs=run_every_secs, every_steps=run_every_steps)\n    self._global_step_tensor = None\n    self._stop_var = None\n    self._stop_op = None\n\n  def begin(self):\n    self._global_step_tensor = tf.compat.v1.train.get_global_step()\n    self._stop_var = _get_or_create_stop_var()\n    self._stop_op = tf.compat.v1.assign(self._stop_var, True)\n\n  def before_run(self, run_context):\n    del run_context\n    return tf.compat.v1.train.SessionRunArgs(self._global_step_tensor)\n\n  def after_run(self, run_context, run_values):\n    global_step = run_values.results\n    if self._timer.should_trigger_for_step(global_step):\n      self._timer.update_last_triggered_step(global_step)\n      if self._should_stop_fn():\n        tf.compat.v1.logging.info('Requesting early stopping at global step %d',\n                                  global_step)\n        run_context.session.run(self._stop_op)\n        run_context.request_stop()\n\n\nclass _CheckForStoppingHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Hook that requests stop if stop is requested by `_StopOnPredicateHook`.\"\"\"\n\n  def __init__(self):\n    
self._stop_var = None\n\n  def begin(self):\n    self._stop_var = _get_or_create_stop_var()\n\n  def before_run(self, run_context):\n    del run_context\n    return tf.compat.v1.train.SessionRunArgs(self._stop_var)\n\n  def after_run(self, run_context, run_values):\n    should_early_stop = run_values.results\n    if should_early_stop:\n      tf.compat.v1.logging.info('Early stopping requested, suspending run.')\n      run_context.request_stop()\n\n\nclass _MultiWorkerEarlyStoppingHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Hook that requests stop when `should_stop_fn` returns `True`.\"\"\"\n\n  def _get_or_create_stop_var_with_aggregation(self):\n    with tf.compat.v1.variable_scope(\n        name_or_scope='signal_early_stopping',\n        values=[],\n        reuse=tf.compat.v1.AUTO_REUSE):\n      return tf.compat.v1.get_variable(\n          name='STOP',\n          shape=[],\n          dtype=tf.dtypes.int32,\n          initializer=tf_keras_v1.initializers.constant(0),\n          collections=[tf.compat.v1.GraphKeys.GLOBAL_VARIABLES],\n          synchronization=tf.VariableSynchronization.ON_WRITE,\n          aggregation=tf.compat.v1.VariableAggregation.SUM,\n          trainable=False)\n\n  def __init__(self, should_stop_fn, run_every_steps=None):\n    if not callable(should_stop_fn):\n      raise TypeError('`should_stop_fn` must be callable.')\n\n    self._should_stop_fn = should_stop_fn\n    self._timer = tf.compat.v1.train.SecondOrStepTimer(\n        every_secs=None, every_steps=run_every_steps)\n    self._global_step_tensor = None\n    self._stop_var = None\n    self._stop_op = None\n    self._non_stop_op = None\n\n  def begin(self):\n    self._global_step_tensor = tf.compat.v1.train.get_global_step()\n    self._stop_var = self._get_or_create_stop_var_with_aggregation()\n    assert tf.distribute.in_cross_replica_context()\n\n    strategy = tf.distribute.get_strategy()\n    self._stop_placeholder = None\n\n    def stop_op_fn(var):\n      placeholder = 
tf.compat.v1.placeholder_with_default(\n          0, tuple(), name='stop_value')\n      if self._stop_placeholder is None:\n        self._stop_placeholder = placeholder\n      return var.assign_add(placeholder)\n\n    self._stop_op = strategy.run(\n        stop_op_fn, args=(self._stop_var,))\n\n  def before_run(self, run_context):\n    del run_context\n    return tf.compat.v1.train.SessionRunArgs({\n        'global_step': self._global_step_tensor,\n        'stop_var': self._stop_var\n    })\n\n  def after_run(self, run_context, run_values):\n    global_step = run_values.results['global_step']\n    should_early_stop = run_values.results['stop_var']\n\n    if should_early_stop > 0:\n      tf.compat.v1.logging.info('Early stopping requested, suspending run.')\n      run_context.request_stop()\n      return\n    if self._timer.should_trigger_for_step(global_step):\n      self._timer.update_last_triggered_step(global_step)\n      if self._should_stop_fn():\n        run_context.session.run(\n            self._stop_op, feed_dict={self._stop_placeholder: 1})\n        tf.compat.v1.logging.info('Requesting early stopping at global step %d',\n                        global_step)\n      else:\n        run_context.session.run(\n            self._stop_op, feed_dict={self._stop_placeholder: 0})\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/early_stopping_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for early_stopping.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport tempfile\n\nfrom absl.testing import parameterized\nfrom absl.testing.absltest import mock\nimport tensorflow as tf\nfrom tensorflow.python.eager import context\nfrom tensorflow_estimator.python.estimator import early_stopping\nfrom tensorflow_estimator.python.estimator import estimator\nfrom tensorflow_estimator.python.estimator import run_config\n\n\nclass _FakeRunConfig(run_config.RunConfig):\n\n  def __init__(self, is_chief):\n    super(_FakeRunConfig, self).__init__()\n    self._is_chief = is_chief\n\n  @property\n  def is_chief(self):\n    return self._is_chief\n\n\ndef _dummy_model_fn(features, labels, params):\n  _, _, _ = features, labels, params\n\n\nclass _FakeEstimator(estimator.Estimator):\n  \"\"\"Fake estimator for testing.\"\"\"\n\n  def __init__(self, config):\n    super(_FakeEstimator, self).__init__(\n        model_fn=_dummy_model_fn, config=config)\n\n\ndef _write_events(eval_dir, params):\n  \"\"\"Test helper to write events to summary files.\"\"\"\n  with context.graph_mode():\n    for steps, loss, accuracy in params:\n      
estimator._write_dict_to_summary(eval_dir, {\n          'loss': loss,\n          'accuracy': accuracy,\n      }, steps)\n\n\nclass ReadEvalMetricsTest(tf.test.TestCase):\n\n  def test_read_eval_metrics(self):\n    eval_dir = tempfile.mkdtemp()\n    _write_events(\n        eval_dir,\n        [\n            # steps, loss, accuracy\n            (1000, 1, 2),\n            (2000, 3, 4),\n            (3000, 5, 6),\n        ])\n    self.assertEqual(\n        {\n            1000: {\n                'loss': 1,\n                'accuracy': 2\n            },\n            2000: {\n                'loss': 3,\n                'accuracy': 4\n            },\n            3000: {\n                'loss': 5,\n                'accuracy': 6\n            },\n        }, early_stopping.read_eval_metrics(eval_dir))\n\n  def test_data_loss_error_ignored(self):\n    eval_dir = tempfile.mkdtemp()\n    _write_events(\n        eval_dir,\n        [\n            # steps, loss, accuracy\n            (1000, 1, 2),\n            (2000, 3, 4),\n            (3000, 5, 6),\n        ])\n\n    orig_tf_train_summary_iterator = tf.compat.v1.train.summary_iterator\n\n    def _summary_iterator(*args, **kwargs):\n      for event in orig_tf_train_summary_iterator(*args, **kwargs):\n        yield event\n        # Raise an error for one of the files after yielding a summary event.\n        if event.HasField('summary'):\n          raise tf.errors.DataLossError(None, None, 'testing data loss')\n\n    with mock.patch.object(tf.compat.v1.train,\n                           'summary_iterator') as mock_summary_iterator:\n      mock_summary_iterator.side_effect = _summary_iterator\n      eval_results = early_stopping.read_eval_metrics(eval_dir)\n\n    self.assertEqual({\n        1000: {\n            'loss': 1,\n            'accuracy': 2\n        }\n    }, eval_results)\n\n  def test_read_eval_metrics_when_no_events(self):\n    eval_dir = tempfile.mkdtemp()\n    self.assertTrue(os.path.exists(eval_dir))\n\n    # No error 
should be raised when eval directory exists with no event files.\n    self.assertEqual({}, early_stopping.read_eval_metrics(eval_dir))\n\n    os.rmdir(eval_dir)\n    self.assertFalse(os.path.exists(eval_dir))\n\n    # No error should be raised when eval directory does not exist.\n    self.assertEqual({}, early_stopping.read_eval_metrics(eval_dir))\n\n\nclass EarlyStoppingHooksTest(tf.test.TestCase, parameterized.TestCase):\n\n  def setUp(self):\n    super(EarlyStoppingHooksTest, self).setUp()\n    config = _FakeRunConfig(is_chief=True)\n    self._estimator = _FakeEstimator(config=config)\n    eval_dir = self._estimator.eval_dir()\n    os.makedirs(eval_dir)\n    _write_events(\n        eval_dir,\n        [\n            # steps, loss, accuracy\n            (1000, 0.8, 0.5),\n            (2000, 0.7, 0.6),\n            (3000, 0.4, 0.7),\n            (3500, 0.41, 0.68),\n        ])\n\n  def run_session(self, hooks, should_stop):\n    hooks = hooks if isinstance(hooks, list) else [hooks]\n    with tf.Graph().as_default():\n      tf.compat.v1.train.create_global_step()\n      no_op = tf.no_op()\n      with tf.compat.v1.train.SingularMonitoredSession(hooks=hooks) as mon_sess:\n        mon_sess.run(no_op)\n        self.assertEqual(mon_sess.should_stop(), should_stop)\n\n  @parameterized.parameters((0.8, 0, False), (0.6, 4000, False), (0.6, 0, True))\n  def test_stop_if_higher_hook(self, threshold, min_steps, should_stop):\n    self.run_session(\n        early_stopping.stop_if_higher_hook(\n            self._estimator,\n            metric_name='accuracy',\n            threshold=threshold,\n            min_steps=min_steps), should_stop)\n\n  @parameterized.parameters((0.3, 0, False), (0.5, 4000, False), (0.5, 0, True))\n  def test_stop_if_lower_hook(self, threshold, min_steps, should_stop):\n    self.run_session(\n        early_stopping.stop_if_lower_hook(\n            self._estimator,\n            metric_name='loss',\n            threshold=threshold,\n            
min_steps=min_steps), should_stop)\n\n  @parameterized.parameters((1500, 0, False), (500, 4000, False),\n                            (500, 0, True))\n  def test_stop_if_no_increase_hook(self, max_steps, min_steps, should_stop):\n    self.run_session(\n        early_stopping.stop_if_no_increase_hook(\n            self._estimator,\n            metric_name='accuracy',\n            max_steps_without_increase=max_steps,\n            min_steps=min_steps), should_stop)\n\n  @parameterized.parameters((1500, 0, False), (500, 4000, False),\n                            (500, 0, True))\n  def test_stop_if_no_decrease_hook(self, max_steps, min_steps, should_stop):\n    self.run_session(\n        early_stopping.stop_if_no_decrease_hook(\n            self._estimator,\n            metric_name='loss',\n            max_steps_without_decrease=max_steps,\n            min_steps=min_steps), should_stop)\n\n  @parameterized.parameters((1500, 0.3, False), (1500, 0.5, True),\n                            (500, 0.3, True))\n  def test_multiple_hooks(self, max_steps, loss_threshold, should_stop):\n    self.run_session([\n        early_stopping.stop_if_no_decrease_hook(\n            self._estimator,\n            metric_name='loss',\n            max_steps_without_decrease=max_steps),\n        early_stopping.stop_if_lower_hook(\n            self._estimator, metric_name='loss', threshold=loss_threshold)\n    ], should_stop)\n\n  @parameterized.parameters(False, True)\n  def test_make_early_stopping_hook(self, should_stop):\n    self.run_session([\n        early_stopping.make_early_stopping_hook(\n            self._estimator, should_stop_fn=lambda: should_stop)\n    ], should_stop)\n\n  def test_make_early_stopping_hook_typeerror(self):\n    with self.assertRaises(TypeError):\n      early_stopping.make_early_stopping_hook(\n          estimator=object(), should_stop_fn=lambda: True)\n\n  def test_make_early_stopping_hook_valueerror(self):\n    with self.assertRaises(ValueError):\n      
early_stopping.make_early_stopping_hook(\n          self._estimator,\n          should_stop_fn=lambda: True,\n          run_every_secs=60,\n          run_every_steps=100)\n\n\nclass StopOnPredicateHookTest(tf.test.TestCase):\n\n  def test_stop(self):\n    hook = early_stopping._StopOnPredicateHook(\n        should_stop_fn=lambda: False, run_every_secs=0)\n    with tf.Graph().as_default():\n      tf.compat.v1.train.create_global_step()\n      no_op = tf.no_op()\n      with tf.compat.v1.train.SingularMonitoredSession(\n          hooks=[hook]) as mon_sess:\n        mon_sess.run(no_op)\n        self.assertFalse(mon_sess.should_stop())\n        self.assertFalse(mon_sess.raw_session().run(hook._stop_var))\n\n    hook = early_stopping._StopOnPredicateHook(\n        should_stop_fn=lambda: True, run_every_secs=0)\n    with tf.Graph().as_default():\n      tf.compat.v1.train.create_global_step()\n      no_op = tf.no_op()\n      with tf.compat.v1.train.SingularMonitoredSession(\n          hooks=[hook]) as mon_sess:\n        mon_sess.run(no_op)\n        self.assertTrue(mon_sess.should_stop())\n        self.assertTrue(mon_sess.raw_session().run(hook._stop_var))\n\n\nclass CheckForStoppingHookTest(tf.test.TestCase):\n\n  def test_stop(self):\n    hook = early_stopping._CheckForStoppingHook()\n    with tf.Graph().as_default():\n      no_op = tf.no_op()\n      assign_op = tf.compat.v1.assign(early_stopping._get_or_create_stop_var(),\n                                      True)\n      with tf.compat.v1.train.SingularMonitoredSession(\n          hooks=[hook]) as mon_sess:\n        mon_sess.run(no_op)\n        self.assertFalse(mon_sess.should_stop())\n\n        mon_sess.run(assign_op)\n\n        # Because there are no guarantees that the stop variable will be read\n        # after the assign op is completed, run another no_op to ensure that the\n        # updated value is read.\n        if not mon_sess.should_stop():\n          mon_sess.run(no_op)\n          
self.assertTrue(mon_sess.should_stop())\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/estimator.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Base Estimator class.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport copy\nimport os\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow as tf\n\nfrom google.protobuf import message  # pylint: disable=g-import-not-at-top\nfrom tensorflow.core.framework import summary_pb2\nfrom tensorflow.python.checkpoint import checkpoint as trackable_util\nfrom tensorflow.python.checkpoint import checkpoint_management\nfrom tensorflow.python.checkpoint import graph_view\nfrom tensorflow.python.distribute import estimator_training as distribute_coordinator_training\nfrom tensorflow.python.eager import context\nfrom tensorflow.python.eager import monitoring\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.profiler import trace\nfrom tensorflow.python.saved_model import path_helpers\nfrom tensorflow.python.summary import summary\nfrom tensorflow.python.training import basic_session_run_hooks\nfrom tensorflow.python.training import device_setter\nfrom tensorflow.python.training import evaluation\nfrom tensorflow.python.training import training\nfrom tensorflow.python.training import training_util\nfrom tensorflow.python.util 
import deprecation\nfrom tensorflow.python.util import function_utils\nfrom tensorflow.python.util import tf_contextlib\nfrom tensorflow.tools.docs import doc_controls\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator import run_config\nfrom tensorflow_estimator.python.estimator import util as estimator_util\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.export import export_lib\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n_VALID_MODEL_FN_ARGS = set(\n    ['features', 'labels', 'mode', 'params', 'self', 'config'])\n_estimator_api_gauge = monitoring.BoolGauge('/tensorflow/api/estimator',\n                                            'estimator api usage', 'method')\n\n_canned_estimator_api_gauge = monitoring.StringGauge(\n    '/tensorflow/api/estimator/canned_estimator',\n    'Gauge to track the type of canned estimator used', 'ClassType')\n\n\n@estimator_export(v1=['estimator.Estimator'])\n@doc_controls.inheritable_header(\"\"\"\\\n  Warning: TensorFlow 2.15 included the final release of the `tf-estimator` \n  package. Estimators will not be available in TensorFlow 2.16 or after. See the\n  [migration guide](https://www.tensorflow.org/guide/migrate/migrating_estimator)\n  for more information about how to convert off of Estimators.\"\n  \"\"\")\nclass Estimator(object):\n  \"\"\"Estimator class to train and evaluate TensorFlow models.\n\n  The `Estimator` object wraps a model which is specified by a `model_fn`,\n  which, given inputs and a number of other parameters, returns the ops\n  necessary to perform training, evaluation, or predictions.\n\n  All outputs (checkpoints, event files, etc.) are written to `model_dir`, or a\n  subdirectory thereof. 
If `model_dir` is not set, a temporary directory is\n  used.\n\n  The `config` argument can be passed `tf.estimator.RunConfig` object containing\n  information about the execution environment. It is passed on to the\n  `model_fn`, if the `model_fn` has a parameter named \"config\" (and input\n  functions in the same manner). If the `config` parameter is not passed, it is\n  instantiated by the `Estimator`. Not passing config means that defaults useful\n  for local execution are used. `Estimator` makes config available to the model\n  (for instance, to allow specialization based on the number of workers\n  available), and also uses some of its fields to control internals, especially\n  regarding checkpointing.\n\n  The `params` argument contains hyperparameters. It is passed to the\n  `model_fn`, if the `model_fn` has a parameter named \"params\", and to the input\n  functions in the same manner. `Estimator` only passes params along, it does\n  not inspect it. The structure of `params` is therefore entirely up to the\n  developer.\n\n  None of `Estimator`'s methods can be overridden in subclasses (its\n  constructor enforces this). 
Subclasses should use `model_fn` to configure\n  the base class, and may add methods implementing specialized functionality.\n\n  See [estimators](https://tensorflow.org/guide/estimator) for more\n  information.\n\n  To warm-start an `Estimator`:\n\n  ```python\n  estimator = tf.estimator.DNNClassifier(\n      feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],\n      hidden_units=[1024, 512, 256],\n      warm_start_from=\"/path/to/checkpoint/dir\")\n  ```\n\n  For more details on warm-start configuration, see\n  `tf.estimator.WarmStartSettings`.\n\n  @compatibility(eager)\n  Calling methods of `Estimator` will work while eager execution is enabled.\n  However, the `model_fn` and `input_fn` is not executed eagerly, `Estimator`\n  will switch to graph mode before calling all user-provided functions (incl.\n  hooks), so their code has to be compatible with graph mode execution. Note\n  that `input_fn` code using `tf.data` generally works in both graph and eager\n  modes.\n  @end_compatibility\n  \"\"\"\n\n  def __init__(self,\n               model_fn,\n               model_dir=None,\n               config=None,\n               params=None,\n               warm_start_from=None):\n    \"\"\"Constructs an `Estimator` instance.\n\n\n\n    Args:\n      model_fn: Model function. Follows the signature:\n        * `features` -- This is the first item returned from the `input_fn`\n        passed to `train`, `evaluate`, and `predict`. This should be a\n        single `tf.Tensor` or `dict` of same.\n        * `labels` -- This is the second item returned from the `input_fn`\n        passed to `train`, `evaluate`, and `predict`. This should be a\n        single `tf.Tensor` or `dict` of same (for multi-head models). If\n        mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will be\n        passed. If the `model_fn`'s signature does not accept `mode`, the\n        `model_fn` must still be able to handle `labels=None`.\n        * `mode` -- Optional. 
Specifies if this is training, evaluation or\n        prediction. See `tf.estimator.ModeKeys`.\n        `params` -- Optional `dict` of hyperparameters.  Will receive what is\n        passed to Estimator in `params` parameter. This allows to configure\n        Estimators from hyper parameter tuning.\n        * `config` -- Optional `estimator.RunConfig` object. Will receive what\n        is passed to Estimator as its `config` parameter, or a default\n        value. Allows setting up things in your `model_fn` based on\n        configuration such as `num_ps_replicas`, or `model_dir`.\n        * Returns -- `tf.estimator.EstimatorSpec`\n      model_dir: Directory to save model parameters, graph and etc. This can\n        also be used to load checkpoints from the directory into an estimator to\n        continue training a previously saved model. If `PathLike` object, the\n        path will be resolved. If `None`, the model_dir in `config` will be used\n        if set. If both are set, they must be same. If both are `None`, a\n        temporary directory will be used.\n      config: `estimator.RunConfig` configuration object.\n      params: `dict` of hyper parameters that will be passed into `model_fn`.\n        Keys are names of parameters, values are basic python types.\n      warm_start_from: Optional string filepath to a checkpoint or SavedModel to\n        warm-start from, or a `tf.estimator.WarmStartSettings` object to fully\n        configure warm-starting.  If None, only TRAINABLE variables are\n        warm-started.  
If the string filepath is provided instead of a\n        `tf.estimator.WarmStartSettings`, then all variables are warm-started,\n        and it is assumed that vocabularies and `tf.Tensor` names are unchanged.\n\n    Raises:\n      ValueError: parameters of `model_fn` don't match `params`.\n      ValueError: if this is called via a subclass and if that class overrides\n        a member of `Estimator`.\n    \"\"\"\n    _estimator_api_gauge.get_cell('init').set(True)\n    # We do not endorse Estimator child classes to override methods in\n    # Estimator, other than a select few. You're on your own if you cleverly\n    # override the method \"_assert_members_are_not_overridden\".\n    self.__class__._assert_members_are_not_overridden(self)  # pylint: disable=protected-access\n\n    self._config = maybe_overwrite_model_dir_and_session_config(\n        config, model_dir)\n\n    # The distribute field contains an instance of tf.distribute.Strategy.\n    self._train_distribution = self._config.train_distribute\n    self._eval_distribution = self._config.eval_distribute\n    # Model directory.\n    self._model_dir = self._config.model_dir\n    self._session_config = self._config.session_config\n    tf.compat.v1.logging.info('Using config: %s', str(vars(self._config)))\n\n    self._device_fn = (\n        self._config.device_fn or _get_replica_device_setter(self._config))\n\n    if model_fn is None:\n      raise ValueError('model_fn must be provided to Estimator.')\n    model_fn_lib.verify_model_fn_args(model_fn, params)\n    self._model_fn = model_fn\n    self._params = copy.deepcopy(params or {})\n\n    # pylint: disable=protected-access\n    self._warm_start_settings = _get_default_warm_start_settings(\n        warm_start_from)\n    # pylint: enable=protected-access\n\n  @property\n  def model_dir(self):\n    return self._model_dir\n\n  @property\n  def config(self):\n    return copy.deepcopy(self._config)\n\n  @property\n  def params(self):\n    return 
copy.deepcopy(self._params)\n\n  @property\n  def model_fn(self):\n    \"\"\"Returns the `model_fn` which is bound to `self.params`.\n\n    Returns:\n      The `model_fn` with following signature:\n        `def model_fn(features, labels, mode, config)`\n    \"\"\"\n\n    def public_model_fn(features, labels, mode, config):\n      return self._call_model_fn(features, labels, mode, config)\n\n    return public_model_fn\n\n  # TODO(ispir): support a list of names\n  def get_variable_value(self, name):\n    \"\"\"Returns value of the variable given by name.\n\n    Args:\n      name: string or a list of string, name of the tensor.\n\n    Returns:\n      Numpy array - value of the tensor.\n\n    Raises:\n      ValueError: If the `Estimator` has not produced a checkpoint yet.\n    \"\"\"\n    _check_checkpoint_available(self.model_dir)\n    with context.graph_mode():\n      return tf.train.load_variable(self.model_dir, name)\n\n  def get_variable_names(self):\n    \"\"\"Returns list of all variable names in this model.\n\n    Returns:\n      List of names.\n\n    Raises:\n      ValueError: If the `Estimator` has not produced a checkpoint yet.\n    \"\"\"\n    _check_checkpoint_available(self.model_dir)\n    with context.graph_mode():\n      return [name for name, _ in tf.train.list_variables(self.model_dir)]\n\n  def latest_checkpoint(self):\n    \"\"\"Finds the filename of the latest saved checkpoint file in `model_dir`.\n\n    Returns:\n      The full path to the latest checkpoint or `None` if no checkpoint was\n      found.\n    \"\"\"\n    with context.graph_mode():\n      return checkpoint_management.latest_checkpoint(self.model_dir)\n\n  def train(self,\n            input_fn,\n            hooks=None,\n            steps=None,\n            max_steps=None,\n            saving_listeners=None):\n    \"\"\"Trains a model given training data `input_fn`.\n\n    Args:\n      input_fn: A function that provides input data for training as minibatches.\n        See [Premade 
Estimators](\n        https://tensorflow.org/guide/premade_estimators#create_input_functions)\n          for more information. The function should construct and return one of\n        the following:\n          * A `tf.data.Dataset` object: Outputs of `Dataset` object must be a\n            tuple `(features, labels)` with same constraints as below.\n          * A tuple `(features, labels)`: Where `features` is a `tf.Tensor` or a\n            dictionary of string feature name to `Tensor` and `labels` is a\n            `Tensor` or a dictionary of string label name to `Tensor`. Both\n            `features` and `labels` are consumed by `model_fn`. They should\n            satisfy the expectation of `model_fn` from inputs.\n      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for\n        callbacks inside the training loop.\n      steps: Number of steps for which to train the model. If `None`, train\n        forever or train until `input_fn` generates the `tf.errors.OutOfRange`\n        error or `StopIteration` exception. `steps` works incrementally. If you\n        call two times `train(steps=10)` then training occurs in total 20 steps.\n        If `OutOfRange` or `StopIteration` occurs in the middle, training stops\n        before 20 steps. If you don't want to have incremental behavior please\n        set `max_steps` instead. If set, `max_steps` must be `None`.\n      max_steps: Number of total steps for which to train model. If `None`,\n        train forever or train until `input_fn` generates the\n        `tf.errors.OutOfRange` error or `StopIteration` exception. If set,\n        `steps` must be `None`. If `OutOfRange` or `StopIteration` occurs in the\n        middle, training stops before `max_steps` steps. Two calls to\n        `train(steps=100)` means 200 training iterations. 
On the other hand, two\n        calls to `train(max_steps=100)` means that the second call will not do\n        any iteration since first call did all 100 steps.\n      saving_listeners: list of `CheckpointSaverListener` objects. Used for\n        callbacks that run immediately before or after checkpoint savings.\n\n    Returns:\n      `self`, for chaining.\n\n    Raises:\n      ValueError: If both `steps` and `max_steps` are not `None`.\n      ValueError: If either `steps` or `max_steps <= 0`.\n    \"\"\"\n    _estimator_api_gauge.get_cell('train').set(True)\n    if self.config.task_type in (run_config.TaskType.EVALUATOR,\n                                 run_config.TaskType.PS):\n      raise ValueError(\n          'Train has been called wrong configuration. Please use '\n          'tf.estimator.train_and_evaluate which calls proper API according '\n          'to given configuration. Current configuration: {}.'.format(\n              self.config))\n\n    with context.graph_mode():\n      if (steps is not None) and (max_steps is not None):\n        raise ValueError('Can not provide both steps and max_steps.')\n      if steps is not None and steps <= 0:\n        raise ValueError('Must specify steps > 0, given: {}'.format(steps))\n      if max_steps is not None and max_steps <= 0:\n        raise ValueError(\n            'Must specify max_steps > 0, given: {}'.format(max_steps))\n\n      if max_steps is not None:\n        start_step = _load_global_step_from_checkpoint_dir(self._model_dir)\n        if max_steps <= start_step:\n          tf.compat.v1.logging.info(\n              'Skipping training since max_steps has already saved.'\n          )\n          return self\n\n      hooks = _check_hooks_type(hooks)\n      hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))\n\n      saving_listeners = _check_listeners_type(saving_listeners)\n      loss = self._train_model(input_fn, hooks, saving_listeners)\n      tf.compat.v1.logging.info('Loss for final step: 
%s.', loss)\n      return self\n\n  def _convert_train_steps_to_hooks(self, steps, max_steps):\n    \"\"\"Create hooks to run correct number of steps in training.\n\n    Args:\n      steps: number of steps to run during training.\n      max_steps: maximum number of steps to be run during training. It'll be the\n        maximum number of steps the model will train to after restoring from\n        checkpoint even across multiple estimator.train calls.\n\n    Returns:\n      List of hooks to be passed to the estimator.\n    \"\"\"\n    if steps is not None or max_steps is not None:\n      if self._train_distribution:\n        steps_per_run = getattr(self._train_distribution.extended,\n                                'steps_per_run', 1)\n        if steps_per_run > 1:\n          return [\n              basic_session_run_hooks._MultiStepStopAtStepHook(  # pylint: disable=protected-access\n                  steps, max_steps, steps_per_run)\n          ]\n      return [tf.compat.v1.train.StopAtStepHook(steps, max_steps)]\n    else:\n      return []\n\n  def eval_dir(self, name=None):\n    \"\"\"Shows the directory name where evaluation metrics are dumped.\n\n    Args:\n      name: Name of the evaluation if user needs to run multiple evaluations on\n        different data sets, such as on training data vs test data. 
Metrics for\n        different evaluations are saved in separate folders, and appear\n        separately in tensorboard.\n\n    Returns:\n      A string which is the path of directory contains evaluation metrics.\n    \"\"\"\n    return os.path.join(self._model_dir, 'eval' if not name else 'eval_' + name)\n\n  def evaluate(self,\n               input_fn,\n               steps=None,\n               hooks=None,\n               checkpoint_path=None,\n               name=None):\n    \"\"\"Evaluates the model given evaluation data `input_fn`.\n\n    For each step, calls `input_fn`, which returns one batch of data.\n    Evaluates until:\n    - `steps` batches are processed, or\n    - `input_fn` raises an end-of-input exception (`tf.errors.OutOfRangeError`\n    or `StopIteration`).\n\n    Args:\n      input_fn: A function that constructs the input data for evaluation. See\n        [Premade Estimators](\n        https://tensorflow.org/guide/premade_estimators#create_input_functions)\n        for more information. The function should construct and return one of\n        the following:\n        * A `tf.data.Dataset` object: Outputs of `Dataset` object must be a\n          tuple `(features, labels)` with same constraints as below.\n        * A tuple `(features, labels)`: Where `features` is a `tf.Tensor` or a\n          dictionary of string feature name to `Tensor` and `labels` is a\n          `Tensor` or a dictionary of string label name to `Tensor`. Both\n          `features` and `labels` are consumed by `model_fn`. They should\n          satisfy the expectation of `model_fn` from inputs.\n      steps: Number of steps for which to evaluate model. If `None`, evaluates\n        until `input_fn` raises an end-of-input exception.\n      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for\n        callbacks inside the evaluation call.\n      checkpoint_path: Path of a specific checkpoint to evaluate. 
If `None`, the\n        latest checkpoint in `model_dir` is used.  If there are no checkpoints\n        in `model_dir`, evaluation is run with newly initialized `Variables`\n        instead of ones restored from checkpoint.\n      name: Name of the evaluation if user needs to run multiple evaluations on\n        different data sets, such as on training data vs test data. Metrics for\n        different evaluations are saved in separate folders, and appear\n        separately in tensorboard.\n\n    Returns:\n      A dict containing the evaluation metrics specified in `model_fn` keyed by\n      name, as well as an entry `global_step` which contains the value of the\n      global step for which this evaluation was performed. For canned\n      estimators, the dict contains the `loss` (mean loss per mini-batch) and\n      the `average_loss` (mean loss per sample). Canned classifiers also return\n      the `accuracy`. Canned regressors also return the `label/mean` and the\n      `prediction/mean`.\n\n    Raises:\n      ValueError: If `steps <= 0`.\n    \"\"\"\n    _estimator_api_gauge.get_cell('evaluate').set(True)\n    # pylint: disable=protected-access\n    if (self._eval_distribution and\n        hasattr(self._config, '_distribute_coordinator_mode') and\n        self._config._distribute_coordinator_mode):\n      return distribute_coordinator_training.estimator_evaluate(\n          self,\n          lambda est, s, eval_hooks: est._actual_eval(  # pylint: disable=g-long-lambda\n              input_fn,\n              strategy=s,\n              steps=steps,\n              hooks=eval_hooks,\n              checkpoint_path=checkpoint_path,\n              name=name),\n          hooks)\n    # pylint: enable=protected-access\n    else:\n      return self._actual_eval(\n          input_fn,\n          strategy=self._eval_distribution,\n          steps=steps,\n          hooks=hooks,\n          checkpoint_path=checkpoint_path,\n          name=name)\n\n  def _actual_eval(self,\n       
            input_fn,\n                   strategy=None,\n                   steps=None,\n                   hooks=None,\n                   checkpoint_path=None,\n                   name=None):\n    \"\"\"The method that does evaluation actually.\"\"\"\n    with context.graph_mode():\n      hooks = _check_hooks_type(hooks)\n      hooks.extend(self._convert_eval_steps_to_hooks(steps))\n\n      # Check that model has been trained (if nothing has been set explicitly).\n      if not checkpoint_path:\n        latest_path = checkpoint_management.latest_checkpoint(self._model_dir)\n        if not latest_path:\n          tf.compat.v1.logging.info(\n              'Could not find trained model in model_dir: {}, running '\n              'initialization to evaluate.'.format(self._model_dir))\n        checkpoint_path = latest_path\n\n      def _evaluate():\n        (scaffold, update_op, eval_dict, all_hooks) = (\n            self._evaluate_build_graph(input_fn, hooks, checkpoint_path))\n        return self._evaluate_run(\n            checkpoint_path=checkpoint_path,\n            scaffold=scaffold,\n            update_op=update_op,\n            eval_dict=eval_dict,\n            all_hooks=all_hooks,\n            output_dir=self.eval_dir(name))\n\n      with tf.Graph().as_default():\n        if strategy:\n          # We want to create the iterations variable outside the distribution\n          # scope as that is just stored on the host and mainly used to drive\n          # the loop and doesn't need to be a Mirrored/Device variable.\n          training.get_or_create_steps_per_run_variable()\n          with strategy.scope():\n            return _evaluate()\n        else:\n          return _evaluate()\n\n  def _convert_eval_steps_to_hooks(self, steps):\n    \"\"\"Create hooks to run correct number of steps in evaluation.\n\n    Args:\n      steps: number of steps to run during evaluation.\n\n    Raises:\n      ValueError: if steps is less than or equal to zero.\n\n    Returns:\n     
 List of hooks to be passed to the estimator.\n    \"\"\"\n    if steps is None:\n      return []\n\n    if steps <= 0:\n      raise ValueError('Must specify steps > 0, given: {}'.format(steps))\n\n    # The hooks are declared as private in evaluation.py discourage the use\n    # by other libraries or open source users. This should be the only usage\n    # of the estimator evaluation hooks.\n    if self._eval_distribution:\n      steps_per_run = getattr(self._eval_distribution.extended, 'steps_per_run',\n                              1)\n      if steps_per_run > 1:\n        return [\n            evaluation._MultiStepStopAfterNEvalsHook(  # pylint: disable=protected-access\n                num_evals=steps,\n                steps_per_run=steps_per_run)\n        ]\n    return [evaluation._StopAfterNEvalsHook(num_evals=steps)]  # pylint: disable=protected-access\n\n  def predict(self,\n              input_fn,\n              predict_keys=None,\n              hooks=None,\n              checkpoint_path=None,\n              yield_single_examples=True):\n    \"\"\"Yields predictions for given features.\n\n    Please note that interleaving two predict outputs does not work. See:\n    [issue/20506](\n    https://github.com/tensorflow/tensorflow/issues/20506#issuecomment-422208517)\n\n    Args:\n      input_fn: A function that constructs the features. Prediction continues\n        until `input_fn` raises an end-of-input exception\n        (`tf.errors.OutOfRangeError` or `StopIteration`). See [Premade\n        Estimators](\n        https://tensorflow.org/guide/premade_estimators#create_input_functions)\n        for more information. The function should construct and return one of\n        the following:\n        * `tf.data.Dataset` object -- Outputs of `Dataset` object must have\n          same constraints as below.\n        * features -- A `tf.Tensor` or a dictionary of string feature name to\n          `Tensor`. features are consumed by `model_fn`. 
They should satisfy\n          the expectation of `model_fn` from inputs.\n        * A tuple, in which case\n          the first item is extracted as features.\n      predict_keys: list of `str`, name of the keys to predict. It is used if\n        the `tf.estimator.EstimatorSpec.predictions` is a `dict`. If\n        `predict_keys` is used then rest of the predictions will be filtered\n        from the dictionary. If `None`, returns all.\n      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for\n        callbacks inside the prediction call.\n      checkpoint_path: Path of a specific checkpoint to predict. If `None`, the\n        latest checkpoint in `model_dir` is used.  If there are no checkpoints\n        in `model_dir`, prediction is run with newly initialized `Variables`\n        instead of ones restored from checkpoint.\n      yield_single_examples: If `False`, yields the whole batch as returned by\n        the `model_fn` instead of decomposing the batch into individual\n        elements. This is useful if `model_fn` returns some tensors whose first\n        dimension is not equal to the batch size.\n\n    Yields:\n      Evaluated values of `predictions` tensors.\n\n    Raises:\n      ValueError: If batch length of predictions is not the same and\n        `yield_single_examples` is `True`.\n      ValueError: If there is a conflict between `predict_keys` and\n        `predictions`. 
For example if `predict_keys` is not `None` but\n        `tf.estimator.EstimatorSpec.predictions` is not a `dict`.\n    \"\"\"\n    _estimator_api_gauge.get_cell('predict').set(True)\n    with context.graph_mode():\n      hooks = _check_hooks_type(hooks)\n      # Check that model has been trained.\n      if not checkpoint_path:\n        checkpoint_path = checkpoint_management.latest_checkpoint(\n            self._model_dir)\n      if not checkpoint_path:\n        tf.compat.v1.logging.info(\n            'Could not find trained model in model_dir: {}, running '\n            'initialization to predict.'.format(self._model_dir))\n      with tf.Graph().as_default() as g:\n        tf.compat.v1.random.set_random_seed(self._config.tf_random_seed)\n        self._create_and_assert_global_step(g)\n        features, input_hooks = self._get_features_from_input_fn(\n            input_fn, ModeKeys.PREDICT)\n        estimator_spec = self._call_model_fn(features, None, ModeKeys.PREDICT,\n                                             self.config)\n\n        # Call to warm_start has to be after model_fn is called.\n        self._maybe_warm_start(checkpoint_path)\n\n        predictions = self._extract_keys(estimator_spec.predictions,\n                                         predict_keys)\n        all_hooks = list(input_hooks)\n        all_hooks.extend(hooks)\n        all_hooks.extend(list(estimator_spec.prediction_hooks or []))\n        with tf.compat.v1.train.MonitoredSession(\n            session_creator=tf.compat.v1.train.ChiefSessionCreator(\n                checkpoint_filename_with_path=checkpoint_path,\n                master=self._config.master,\n                scaffold=estimator_spec.scaffold,\n                config=self._session_config),\n            hooks=all_hooks) as mon_sess:\n          while not mon_sess.should_stop():\n            preds_evaluated = mon_sess.run(predictions)\n            if not yield_single_examples:\n              yield preds_evaluated\n            
elif not isinstance(predictions, dict):\n              for pred in preds_evaluated:\n                yield pred\n            else:\n              for i in range(self._extract_batch_length(preds_evaluated)):\n                yield {\n                    key: value[i]\n                    for key, value in six.iteritems(preds_evaluated)\n                }\n\n  def _assert_members_are_not_overridden(self):\n    \"\"\"Asserts members of `Estimator` are not overridden.\"\"\"\n    _assert_members_are_not_overridden(Estimator, self)\n\n  def export_saved_model(self,\n                         export_dir_base,\n                         serving_input_receiver_fn,\n                         assets_extra=None,\n                         as_text=False,\n                         checkpoint_path=None,\n                         experimental_mode=ModeKeys.PREDICT):\n    # pylint: disable=line-too-long\n    \"\"\"Exports inference graph as a `SavedModel` into the given dir.\n\n    For a detailed guide on SavedModel, see\n    [Using the SavedModel format]\n    (https://tensorflow.org/guide/saved_model#savedmodels_from_estimators).\n\n    This method builds a new graph by first calling the\n    `serving_input_receiver_fn` to obtain feature `Tensor`s, and then calling\n    this `Estimator`'s `model_fn` to generate the model graph based on those\n    features. It restores the given checkpoint (or, lacking that, the most\n    recent checkpoint) into this graph in a fresh session.  Finally it creates\n    a timestamped export directory below the given `export_dir_base`, and writes\n    a `SavedModel` into it containing a single `tf.MetaGraphDef` saved from this\n    session.\n\n    The exported `MetaGraphDef` will provide one `SignatureDef` for each\n    element of the `export_outputs` dict returned from the `model_fn`, named\n    using the same keys.  
One of these keys is always\n    `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`,\n    indicating which signature will be served when a serving request does not\n    specify one. For each signature, the outputs are provided by the\n    corresponding `tf.estimator.export.ExportOutput`s, and the inputs are always\n    the input receivers provided by the `serving_input_receiver_fn`.\n\n    Extra assets may be written into the `SavedModel` via the `assets_extra`\n    argument.  This should be a dict, where each key gives a destination path\n    (including the filename) relative to the assets.extra directory.  The\n    corresponding value gives the full path of the source file to be copied.\n    For example, the simple case of copying a single file without renaming it\n    is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.\n\n    The experimental_mode parameter can be used to export a single\n    train/eval/predict graph as a `SavedModel`.\n    See `experimental_export_all_saved_models` for full docs.\n\n    Args:\n      export_dir_base: A string containing a directory in which to create\n        timestamped subdirectories containing exported `SavedModel`s.\n      serving_input_receiver_fn: A function that takes no argument and returns a\n        `tf.estimator.export.ServingInputReceiver` or\n        `tf.estimator.export.TensorServingInputReceiver`.\n      assets_extra: A dict specifying how to populate the assets.extra directory\n        within the exported `SavedModel`, or `None` if no extra assets are\n        needed.\n      as_text: whether to write the `SavedModel` proto in text format.\n      checkpoint_path: The checkpoint path to export.  If `None` (the default),\n        the most recent checkpoint found within the model directory is chosen.\n      experimental_mode: `tf.estimator.ModeKeys` value indicating with mode will\n        be exported. 
Note that this feature is experimental.\n\n    Returns:\n      The path to the exported directory as a bytes object.\n\n    Raises:\n      ValueError: if no `serving_input_receiver_fn` is provided, no\n      `export_outputs` are provided, or no checkpoint can be found.\n    \"\"\"\n    # pylint: enable=line-too-long\n    if not serving_input_receiver_fn:\n      raise ValueError('An input_receiver_fn must be defined.')\n\n    input_receiver_fn_map = {experimental_mode: serving_input_receiver_fn}\n\n    return self._export_all_saved_models(\n        export_dir_base,\n        input_receiver_fn_map,\n        assets_extra=assets_extra,\n        as_text=as_text,\n        checkpoint_path=checkpoint_path,\n        strip_default_attrs=True)\n\n  def experimental_export_all_saved_models(self,\n                                           export_dir_base,\n                                           input_receiver_fn_map,\n                                           assets_extra=None,\n                                           as_text=False,\n                                           checkpoint_path=None):\n    \"\"\"Exports a `SavedModel` with `tf.MetaGraphDefs` for each requested mode.\n\n    For each mode passed in via the `input_receiver_fn_map`,\n    this method builds a new graph by calling the `input_receiver_fn` to obtain\n    feature and label `Tensor`s. 
Next, this method calls the `Estimator`'s\n    `model_fn` in the passed mode to generate the model graph based on\n    those features and labels, and restores the given checkpoint\n    (or, lacking that, the most recent checkpoint) into the graph.\n    Only one of the modes is used for saving variables to the `SavedModel`\n    (order of preference: `tf.estimator.ModeKeys.TRAIN`,\n    `tf.estimator.ModeKeys.EVAL`, then\n    `tf.estimator.ModeKeys.PREDICT`), such that up to three\n    `tf.MetaGraphDefs` are saved with a single set of variables in a single\n    `SavedModel` directory.\n\n    For the variables and `tf.MetaGraphDefs`, a timestamped export directory\n    below `export_dir_base`, and writes a `SavedModel` into it containing the\n    `tf.MetaGraphDef` for the given mode and its associated signatures.\n\n    For prediction, the exported `MetaGraphDef` will provide one `SignatureDef`\n    for each element of the `export_outputs` dict returned from the `model_fn`,\n    named using the same keys.  One of these keys is always\n    `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`,\n    indicating which signature will be served when a serving request does not\n    specify one. For each signature, the outputs are provided by the\n    corresponding `tf.estimator.export.ExportOutput`s, and the inputs are always\n    the input receivers provided by the `serving_input_receiver_fn`.\n\n    For training and evaluation, the `train_op` is stored in an extra\n    collection, and loss, metrics, and predictions are included in a\n    `SignatureDef` for the mode in question.\n\n    Extra assets may be written into the `SavedModel` via the `assets_extra`\n    argument.  This should be a dict, where each key gives a destination path\n    (including the filename) relative to the assets.extra directory.  
The\n    corresponding value gives the full path of the source file to be copied.\n    For example, the simple case of copying a single file without renaming it\n    is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.\n\n    Args:\n      export_dir_base: A string containing a directory in which to create\n        timestamped subdirectories containing exported `SavedModel`s.\n      input_receiver_fn_map: dict of `tf.estimator.ModeKeys` to\n        `input_receiver_fn` mappings, where the `input_receiver_fn` is a\n        function that takes no arguments and returns the appropriate subclass of\n        `InputReceiver`.\n      assets_extra: A dict specifying how to populate the assets.extra directory\n        within the exported `SavedModel`, or `None` if no extra assets are\n        needed.\n      as_text: whether to write the `SavedModel` proto in text format.\n      checkpoint_path: The checkpoint path to export.  If `None` (the default),\n        the most recent checkpoint found within the model directory is chosen.\n\n    Returns:\n      The path to the exported directory as a bytes object.\n\n    Raises:\n      ValueError: if any `input_receiver_fn` is `None`, no `export_outputs`\n        are provided, or no checkpoint can be found.\n    \"\"\"\n    return self._export_all_saved_models(\n        export_dir_base,\n        input_receiver_fn_map,\n        assets_extra=assets_extra,\n        as_text=as_text,\n        checkpoint_path=checkpoint_path,\n        strip_default_attrs=True)\n\n  def _export_all_saved_models(self,\n                               export_dir_base,\n                               input_receiver_fn_map,\n                               assets_extra=None,\n                               as_text=False,\n                               checkpoint_path=None,\n                               strip_default_attrs=True):\n    \"\"\"Exports multiple modes in the model function to a SavedModel.\"\"\"\n    # TODO(b/65561022): Consider 
allowing multiple input_receiver_fns per mode.\n    with context.graph_mode():\n      if not checkpoint_path:\n        # Locate the latest checkpoint\n        checkpoint_path = self.latest_checkpoint()\n      if not checkpoint_path:\n        if self._warm_start_settings:\n          checkpoint_path = self._warm_start_settings.ckpt_to_initialize_from\n          if tf.compat.v1.gfile.IsDirectory(checkpoint_path):\n            checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)\n        else:\n          raise ValueError(\"Couldn't find trained model at {}.\".format(\n              self._model_dir))\n\n      export_dir = export_lib.get_timestamped_export_dir(export_dir_base)\n      temp_export_dir = export_lib.get_temp_export_dir(export_dir)\n\n      builder = tf.compat.v1.saved_model.Builder(temp_export_dir)\n\n      save_variables = True\n      # Note that the order in which we run here matters, as the first\n      # mode we pass through will be used to save the variables. We run TRAIN\n      # first, as that is also the mode used for checkpoints, and therefore\n      # we are not likely to have vars in PREDICT that are not in the checkpoint\n      # created by TRAIN.\n      if input_receiver_fn_map.get(ModeKeys.TRAIN):\n        self._add_meta_graph_for_mode(\n            builder,\n            input_receiver_fn_map,\n            checkpoint_path,\n            save_variables,\n            mode=ModeKeys.TRAIN,\n            strip_default_attrs=strip_default_attrs)\n        save_variables = False\n      if input_receiver_fn_map.get(ModeKeys.EVAL):\n        self._add_meta_graph_for_mode(\n            builder,\n            input_receiver_fn_map,\n            checkpoint_path,\n            save_variables,\n            mode=ModeKeys.EVAL,\n            strip_default_attrs=strip_default_attrs)\n        save_variables = False\n      if input_receiver_fn_map.get(ModeKeys.PREDICT):\n        self._add_meta_graph_for_mode(\n            builder,\n            
input_receiver_fn_map,\n            checkpoint_path,\n            save_variables,\n            mode=ModeKeys.PREDICT,\n            strip_default_attrs=strip_default_attrs)\n        save_variables = False\n\n      if save_variables:\n        raise ValueError('No valid modes for exporting found. Got {}.'.format(\n            input_receiver_fn_map.keys()))\n\n      builder.save(as_text)\n\n      # Add the extra assets\n      if assets_extra:\n        assets_extra_path = os.path.join(\n            tf.compat.as_bytes(temp_export_dir),\n            tf.compat.as_bytes('assets.extra'))\n        for dest_relative, source in assets_extra.items():\n          dest_absolute = os.path.join(\n              tf.compat.as_bytes(assets_extra_path),\n              tf.compat.as_bytes(dest_relative))\n          dest_path = os.path.dirname(dest_absolute)\n          tf.compat.v1.gfile.MakeDirs(dest_path)\n          tf.compat.v1.gfile.Copy(source, dest_absolute)\n\n      tf.compat.v1.gfile.Rename(temp_export_dir, export_dir)\n      return export_dir\n\n  def _add_meta_graph_for_mode(self,\n                               builder,\n                               input_receiver_fn_map,\n                               checkpoint_path,\n                               save_variables=True,\n                               mode=ModeKeys.PREDICT,\n                               export_tags=None,\n                               check_variables=True,\n                               strip_default_attrs=True):\n    \"\"\"Loads variables and adds them along with a `tf.MetaGraphDef` for saving.\n\n    Args:\n      builder: instance of `tf.saved_modle.builder.SavedModelBuilder` that will\n        be used for saving.\n      input_receiver_fn_map: dict of `tf.estimator.ModeKeys` to\n        `input_receiver_fn` mappings, where the `input_receiver_fn` is a\n        function that takes no argument and returns the appropriate subclass of\n        `InputReceiver`.\n      checkpoint_path: The checkpoint path to 
export.\n      save_variables: bool, whether variables should be saved. If `False`, just\n        the `tf.MetaGraphDef` will be saved. Note that `save_variables` should\n        only be `True` for the first call to this function, and the\n        `SavedModelBuilder` will raise an error if that is not the case.\n      mode: `tf.estimator.ModeKeys` value indicating which mode will be\n        exported.\n      export_tags: The set of tags with which to save `tf.MetaGraphDef`. If\n        `None`, a default set will be selected to matched the passed mode.\n      check_variables: bool, whether to check the checkpoint has all variables.\n      strip_default_attrs: bool, whether to strip default attributes. This may\n        only be True when called from the deprecated V1\n        Estimator.export_savedmodel.\n\n    Raises:\n      ValueError: if `save_variables` is `True` and `check_variable` is `False`.\n    \"\"\"\n    if export_tags is None:\n      export_tags = export_lib.EXPORT_TAG_MAP[mode]\n    input_receiver_fn = input_receiver_fn_map[mode]\n\n    with tf.Graph().as_default() as g:\n      self._create_and_assert_global_step(g)\n      tf.compat.v1.random.set_random_seed(self._config.tf_random_seed)\n\n      input_receiver = input_receiver_fn()\n\n      # Call the model_fn and collect the export_outputs.\n      estimator_spec = self._call_model_fn(\n          features=input_receiver.features,\n          labels=getattr(input_receiver, 'labels', None),\n          mode=mode,\n          config=self.config)\n\n      export_outputs = export_lib.export_outputs_for_mode(\n          mode=estimator_spec.mode,\n          serving_export_outputs=estimator_spec.export_outputs,\n          predictions=estimator_spec.predictions,\n          loss=estimator_spec.loss,\n          metrics=estimator_spec.eval_metric_ops)\n\n      # Build the SignatureDefs from receivers and all outputs\n      signature_def_map = export_lib.build_all_signature_defs(\n          
input_receiver.receiver_tensors,\n          export_outputs,\n          getattr(input_receiver, 'receiver_tensors_alternatives', None),\n          serving_only=(mode == ModeKeys.PREDICT))\n\n      with tf.compat.v1.Session(config=self._session_config) as session:\n\n        if estimator_spec.scaffold.local_init_op is not None:\n          local_init_op = estimator_spec.scaffold.local_init_op\n        else:\n          local_init_op = tf.compat.v1.train.Scaffold.default_local_init_op()\n\n        # This saver will be used both for restoring variables now,\n        # and in saving out the metagraph below. This ensures that any\n        # Custom Savers stored with the Scaffold are passed through to the\n        # SavedModel for restore later.\n        if isinstance(estimator_spec.scaffold.saver, trackable_util.Checkpoint):\n          graph_saver = tf.compat.v1.train.Saver(\n              var_list=graph_view.ObjectGraphView(\n                  estimator_spec.scaffold.saver).frozen_saveable_objects(),\n              sharded=True)\n        else:\n          graph_saver = (\n              estimator_spec.scaffold.saver or\n              tf.compat.v1.train.Saver(sharded=True))\n\n        if save_variables and not check_variables:\n          raise ValueError('If `save_variables` is `True, `check_variables`'\n                           'must not be `False`.')\n        if check_variables:\n          try:\n            graph_saver.restore(session, checkpoint_path)\n          except tf.errors.NotFoundError as e:\n            msg = ('Could not load all requested variables from checkpoint. '\n                   'Please make sure your model_fn does not expect variables '\n                   'that were not saved in the checkpoint.\\n\\n'\n                   'Encountered error with mode `{}` while restoring '\n                   'checkpoint from: `{}`. 
Full Traceback:\\n\\n{}').format(\n                       mode, checkpoint_path, e)\n            raise ValueError(msg)\n\n        # We add the train op explicitly for now, so that we don't have to\n        # change the Builder public interface. Note that this is a no-op\n        # for prediction, where train_op is None.\n        builder._add_train_op(estimator_spec.train_op)  # pylint: disable=protected-access\n\n        meta_graph_kwargs = dict(\n            tags=export_tags,\n            signature_def_map=signature_def_map,\n            assets_collection=tf.compat.v1.get_collection(\n                tf.compat.v1.GraphKeys.ASSET_FILEPATHS),\n            main_op=local_init_op,\n            saver=graph_saver,\n            strip_default_attrs=strip_default_attrs)\n\n        if save_variables:\n          builder.add_meta_graph_and_variables(session, **meta_graph_kwargs)\n        else:\n          builder.add_meta_graph(**meta_graph_kwargs)\n\n  def _get_features_from_input_fn(self, input_fn, mode):\n    \"\"\"Extracts the `features` from return values of `input_fn`.\"\"\"\n    result = self._call_input_fn(input_fn, mode)\n    result, _, hooks = estimator_util.parse_input_fn_result(result)\n    self._validate_features_in_predict_input(result)\n    return result, hooks\n\n  def _validate_features_in_predict_input(self, result):\n    if not _has_dataset_or_queue_runner(result):\n      tf.compat.v1.logging.warning(\n          'Input graph does not use tf.data.Dataset or contain a '\n          'QueueRunner. That means predict yields forever. 
'\n          'This is probably a mistake.'\n      )\n\n  def _get_iterator_from_input_fn(self, input_fn, mode, distribution=None):\n    \"\"\"Calls `input_fn` and returns an iterator.\"\"\"\n    if distribution is not None:\n      # pylint: disable=g-long-lambda\n      iterator = distribution.make_input_fn_iterator(\n          lambda input_context: self._call_input_fn(input_fn, mode,\n                                                    input_context))\n      input_hooks = [\n          estimator_util.DistributedIteratorInitializerHook(iterator)\n      ]\n    else:\n      result = self._call_input_fn(input_fn, mode)\n      iterator = result.make_initializable_iterator()\n      input_hooks = [estimator_util._DatasetInitializerHook(iterator)]  # pylint: disable=protected-access\n    return iterator, input_hooks\n\n  def _get_features_and_labels_from_input_fn(self, input_fn, mode):\n    \"\"\"Extracts the `features` and labels from return values of `input_fn`.\"\"\"\n    return estimator_util.parse_input_fn_result(\n        self._call_input_fn(input_fn, mode))\n\n  def _extract_batch_length(self, preds_evaluated):\n    \"\"\"Extracts batch length of predictions.\"\"\"\n    batch_length = None\n    for key, value in six.iteritems(preds_evaluated):\n      batch_length = batch_length or value.shape[0]\n      if value.shape[0] != batch_length:\n        raise ValueError('Batch length of predictions should be same. %s has '\n                         'different batch length than others.' 
% key)\n    return batch_length\n\n  def _extract_keys(self, predictions, predict_keys):\n    \"\"\"Extracts `predict_keys` from `predictions`.\"\"\"\n    if not predict_keys:\n      return predictions\n    if not isinstance(predictions, dict):\n      raise ValueError(\n          'predict_keys argument is not valid in case of non-dict predictions.')\n    existing_keys = predictions.keys()\n    predictions = {\n        key: value\n        for key, value in six.iteritems(predictions)\n        if key in predict_keys\n    }\n    if not predictions:\n      raise ValueError('Expected to run at least one output from %s, '\n                       'provided %s.' % (existing_keys, predict_keys))\n    return predictions\n\n  def _create_global_step(self, graph):\n    \"\"\"Creates the global step tensor in graph.\n\n    The global step tensor must be an integer type with name 'global_step' and\n    be added to the collection `tf.GraphKeys.GLOBAL_STEP`.\n\n    Args:\n      graph: The graph in which to create the global step tensor.\n\n    Returns:\n      The global step `tf.Tensor`.\n    \"\"\"\n    return tf.compat.v1.train.create_global_step(graph)\n\n  def _create_and_assert_global_step(self, graph):\n    \"\"\"Creates and asserts properties of the global step.\n\n    Args:\n      graph: The graph in which to create the global step tensor.\n\n    Returns:\n      The global step `tf.Tensor`.\n    \"\"\"\n    step = self._create_global_step(graph)\n    assert step is tf.compat.v1.train.get_global_step()\n    assert step.dtype.is_integer\n    return step\n\n  def _call_input_fn(self, input_fn, mode, input_context=None):\n    \"\"\"Calls the input function.\n\n    Args:\n      input_fn: The input function.\n      mode: `tf.estimator.ModeKeys`\n\n    Returns:\n      The return value of the passed `input_fn`, which should be one of:\n\n        * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a\n          tuple `(features, labels)` with same constraints as 
below.\n        * A tuple `(features, labels)`: Where `features` is a `Tensor` or a\n          dictionary of string feature name to `Tensor` and `labels` is a\n          `Tensor` or a dictionary of string label name to `Tensor`. Both\n          `features` and `labels` are consumed by `model_fn`. They should\n          satisfy the expectation of `model_fn` from inputs.\n\n    Raises:\n      ValueError: if `input_fn` takes invalid arguments.\n    \"\"\"\n    input_fn_args = function_utils.fn_args(input_fn)\n    kwargs = {}\n    if 'mode' in input_fn_args:\n      kwargs['mode'] = mode\n    if 'params' in input_fn_args:\n      kwargs['params'] = self.params\n    if 'config' in input_fn_args:\n      kwargs['config'] = self.config\n    if input_context and 'input_context' in input_fn_args:\n      tf.compat.v1.logging.info(\n          'The `input_fn` accepts an `input_context` which will '\n          'be given by DistributionStrategy')\n      kwargs['input_context'] = input_context\n    with tf.compat.v1.device('/cpu:0'):\n      return input_fn(**kwargs)\n\n  def _call_model_fn(self, features, labels, mode, config):\n    \"\"\"Calls model function.\n\n    Args:\n      features: features dict.\n      labels: labels dict.\n      mode: `tf.estimator.ModeKeys`\n      config: `tf.estimator.RunConfig`\n\n    Returns:\n      An `tf.estimator.EstimatorSpec` object.\n\n    Raises:\n      ValueError: if `model_fn` returns invalid objects.\n    \"\"\"\n    model_fn_args = function_utils.fn_args(self._model_fn)\n    kwargs = {}\n    if 'labels' in model_fn_args:\n      kwargs['labels'] = labels\n    else:\n      if labels is not None:\n        raise ValueError(\n            'model_fn does not take labels, but input_fn returns labels.')\n    if 'mode' in model_fn_args:\n      kwargs['mode'] = mode\n    if 'params' in model_fn_args:\n      kwargs['params'] = self.params\n    if 'config' in model_fn_args:\n      kwargs['config'] = config\n\n    tf.compat.v1.logging.info('Calling 
model_fn.')\n    model_fn_results = self._model_fn(features=features, **kwargs)\n    tf.compat.v1.logging.info('Done calling model_fn.')\n\n    if not isinstance(model_fn_results, model_fn_lib.EstimatorSpec):\n      raise ValueError('model_fn should return an EstimatorSpec.')\n\n    return model_fn_results\n\n  def _train_model(self, input_fn, hooks, saving_listeners):\n    if self._train_distribution:\n      return self._train_model_distributed(input_fn, hooks, saving_listeners)\n    else:\n      return self._train_model_default(input_fn, hooks, saving_listeners)\n\n  def _train_model_default(self, input_fn, hooks, saving_listeners):\n    \"\"\"Initiate training with `input_fn`, without `DistributionStrategies`.\n\n    Args:\n      input_fn: A function that provides input data for training as minibatches.\n      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for\n        callbacks inside the training loop.\n      saving_listeners: list of `tf.train.CheckpointSaverListener` objects. Used\n        for callbacks that run immediately before or after checkpoint savings.\n\n    Returns:\n      Loss from training\n    \"\"\"\n    worker_hooks = []\n    with tf.Graph().as_default() as g, g.device(self._device_fn):\n      tf.compat.v1.random.set_random_seed(self._config.tf_random_seed)\n      global_step_tensor = self._create_and_assert_global_step(g)\n\n      # Skip creating a read variable if _create_and_assert_global_step\n      # returns None (e.g. 
tf.contrib.estimator.SavedModelEstimator).\n      if global_step_tensor is not None:\n        training_util._get_or_create_global_step_read(g)  # pylint: disable=protected-access\n\n      features, labels, input_hooks = (\n          self._get_features_and_labels_from_input_fn(input_fn, ModeKeys.TRAIN))\n      worker_hooks.extend(input_hooks)\n      estimator_spec = self._call_model_fn(features, labels, ModeKeys.TRAIN,\n                                           self.config)\n      global_step_tensor = tf.compat.v1.train.get_global_step(g)\n      return self._train_with_estimator_spec(estimator_spec, worker_hooks,\n                                             hooks, global_step_tensor,\n                                             saving_listeners)\n\n  def _train_model_distributed(self, input_fn, hooks, saving_listeners):\n    \"\"\"Initiate training with `input_fn`, using `DistributionStrategies`.\n\n    Args:\n      input_fn: A function that provides input data for training as minibatches.\n      hooks: List of `tf.train.SessionRunHook` subclass instances. Used for\n        callbacks inside the training loop.\n      saving_listeners: list of `tf.train.CheckpointSaverListener` objects. 
Used\n        for callbacks that run immediately before or after checkpoint savings.\n\n    Returns:\n      Loss from training\n    \"\"\"\n    # pylint: disable=protected-access\n    if (hasattr(self._config, '_distribute_coordinator_mode') and\n        self._config._distribute_coordinator_mode):  # pylint: disable=protected-access\n      distribute_coordinator_training.estimator_train(\n          self,\n          lambda est, s, train_hooks: est._actual_train_model_distributed(  # pylint: disable=g-long-lambda\n              s, input_fn, train_hooks, saving_listeners),\n          hooks)\n      return self\n    else:\n      self._config._train_distribute.configure(self._config.session_config)\n      return self._actual_train_model_distributed(\n          self._config._train_distribute, input_fn, hooks, saving_listeners)\n    # pylint: enable=protected-access\n\n  def _actual_train_model_distributed(self, strategy, input_fn, hooks,\n                                      saving_listeners):\n    \"\"\"That method that does actual training with distribution strategy.\"\"\"\n    # TODO(sourabhbajaj): Remove this hack once we migrate the other strategies\n    # to use the new API\n    is_tpu_strategy = strategy.__class__.__name__.startswith('TPUStrategy')\n\n    worker_hooks = []\n    with tf.Graph().as_default() as g:\n      # We want to create the iterations variable outside the distribution scope\n      # as that is just stored on the host and mainly used to drive the loop\n      # and doesn't need to be a Mirrored/Device variable.\n      if is_tpu_strategy:\n        steps_per_run_variable = training.get_or_create_steps_per_run_variable()\n\n      # Set flag on the distribution strategy so that optimizer v1 is\n      # distribution aware and scales the losses by number of replicas.\n      # This is required only for backward compatibility with estimator and\n      # V1 optimizer. 
TF2 will not do this scaling.\n      if hasattr(strategy, '_scale_loss_for_estimator_enabled'):\n        scale_ctx = strategy._scale_loss_for_estimator_enabled()  # pylint: disable=protected-access\n      else:\n        # TODO(psv): Remove this clause after estimator repo gets the\n        # distribute library changes related to loss scaling.\n        @tf_contextlib.contextmanager\n        def nullcontextmanager():\n          yield\n\n        scale_ctx = nullcontextmanager()\n\n      with strategy.scope(), scale_ctx:\n        tf.compat.v1.random.set_random_seed(self._config.tf_random_seed)\n        iterator, input_hooks = self._get_iterator_from_input_fn(\n            input_fn, ModeKeys.TRAIN, strategy)\n        worker_hooks.extend(input_hooks)\n        global_step_tensor = self._create_and_assert_global_step(g)\n        # we want to add to the global collection in the main thread not the\n        # replica threads.\n        tf.compat.v1.add_to_collection(\n            training_util.GLOBAL_STEP_READ_KEY,\n            strategy.extended.read_var(global_step_tensor))\n\n        if is_tpu_strategy:\n          # Create a step_fn from the train_op of grouped_estimator_spec\n          def step_fn(ctx, inputs):\n            \"\"\"A single step that is passed to run_on_dataset.\"\"\"\n            if isinstance(inputs, tuple):\n              features, labels = inputs\n            else:\n              features = inputs\n              labels = None\n            estimator_spec = strategy.extended.call_for_each_replica(\n                self._call_model_fn,\n                args=(features, labels, ModeKeys.TRAIN, self.config))\n            ctx.set_last_step_output(\n                name='loss',\n                output=estimator_spec.loss,\n                reduce_op=_get_loss_reduce_op_for_reporting())\n            ctx.set_non_tensor_output(\n                name='estimator_spec', output=estimator_spec)\n            return estimator_spec.train_op\n\n          # Create new 
train_op post graph rewrites\n          initial_training_loss = tf.constant(1e7)\n          ctx = strategy.extended.experimental_run_steps_on_iterator(\n              step_fn,\n              iterator,\n              iterations=steps_per_run_variable,\n              initial_loop_values={'loss': initial_training_loss})\n          distributed_train_op = ctx.run_op\n          loss = ctx.last_step_outputs['loss']\n          grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']\n        else:\n          features, labels = estimator_util.parse_iterator_result(\n              iterator.get_next())\n          grouped_estimator_spec = strategy.extended.call_for_each_replica(\n              self._call_model_fn,\n              args=(\n                  features,\n                  labels,  # although this will be None it seems\n                  ModeKeys.TRAIN,\n                  self.config))\n          loss = strategy.reduce(\n              _get_loss_reduce_op_for_reporting(),\n              grouped_estimator_spec.loss,\n              axis=None)\n          distributed_train_op = grouped_estimator_spec.train_op\n\n        scaffold = _combine_distributed_scaffold(\n            grouped_estimator_spec.scaffold, strategy)\n\n        # TODO(yuefengz): add a test for unwrapping per_device_hooks.\n        def get_hooks_from_the_first_device(per_device_hooks):\n          return [\n              self._train_distribution.experimental_local_results(\n                  per_device_hook)[0] for per_device_hook in per_device_hooks\n          ]\n\n        training_hooks = get_hooks_from_the_first_device(\n            grouped_estimator_spec.training_hooks)\n        training_chief_hooks = get_hooks_from_the_first_device(\n            grouped_estimator_spec.training_chief_hooks)\n        estimator_spec = model_fn_lib.EstimatorSpec(\n            mode=grouped_estimator_spec.mode,\n            loss=loss,\n            train_op=strategy.group(distributed_train_op),\n            
training_hooks=training_hooks,\n            training_chief_hooks=training_chief_hooks,\n            scaffold=scaffold)\n        return self._train_with_estimator_spec(estimator_spec, worker_hooks,\n                                               hooks, global_step_tensor,\n                                               saving_listeners)\n\n  def _train_with_estimator_spec_distributed(self, estimator_spec, worker_hooks,\n                                             saving_listener):\n    \"\"\"Train a model with the given Estimator Spec and Distribution Strategy.\"\"\"\n    if saving_listener:\n      raise ValueError('Saving listenor is not supported by the current '\n                       'Distribution Strategies.')\n    #TODO: consolidate code duplication in _train_with_estimator_spec\n    with training.MonitoredTrainingSession(\n        master=self._config.master,\n        is_chief=self._config.is_chief,\n        checkpoint_dir=self._model_dir,\n        scaffold=estimator_spec.scaffold,\n        hooks=worker_hooks,\n        chief_only_hooks=tuple(estimator_spec.training_chief_hooks),\n        save_checkpoint_secs=self._config.save_checkpoints_secs,\n        save_checkpoint_steps=self._config.save_checkpoints_steps,\n        save_summaries_steps=self._config.save_summary_steps,\n        config=self._session_config,\n        max_wait_secs=self._config.session_creation_timeout_secs,\n        log_step_count_steps=self._config.log_step_count_steps,\n        save_graph_def=self._config.checkpoint_save_graph_def) as mon_sess:\n      loss = None\n      current_step = 0\n      while not mon_sess.should_stop():\n        current_step += 1\n        # just as keras(https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/python/keras/engine/training.py#L1093),\n        # trace should be enabled for every step\n        with trace.Trace('train', step_num=current_step, _r=1):\n          _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])\n      if 
current_step == 0:\n        tf.compat.v1.logging.warn('Training with estimator made no steps. '\n                                  'Perhaps input is empty or misspecified.')\n    return loss\n\n  def _train_with_estimator_spec(self, estimator_spec, worker_hooks, hooks,\n                                 global_step_tensor, saving_listeners):\n    \"\"\"Train a model with the given Estimator Spec.\"\"\"\n    if (self._warm_start_settings and\n        not tf.train.latest_checkpoint(self._model_dir)):\n      tf.compat.v1.logging.info('Warm-starting with WarmStartSettings: %s' %\n                                (self._warm_start_settings,))\n      tf.compat.v1.train.warm_start(*self._warm_start_settings)\n    # Check if the user created a loss summary, and add one if they didn't.\n    # We assume here that the summary is called 'loss'. If it is not, we will\n    # make another one with the name 'loss' to ensure it shows up in the right\n    # graph in TensorBoard.\n    if not any([\n        x.op.name == 'loss' for x in ops.get_collection(ops.GraphKeys.SUMMARIES)\n    ]):\n      summary.scalar('loss', estimator_spec.loss)\n    ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)\n    worker_hooks.extend(hooks)\n    worker_hooks.append(tf.compat.v1.train.NanTensorHook(estimator_spec.loss))\n    if self._config.log_step_count_steps is not None:\n      worker_hooks.append(\n          tf.compat.v1.train.LoggingTensorHook(\n              {\n                  'loss': estimator_spec.loss,\n                  'step': global_step_tensor\n              },\n              every_n_iter=self._config.log_step_count_steps))\n    worker_hooks.extend(estimator_spec.training_hooks)\n\n    if not (estimator_spec.scaffold.saver or\n            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SAVERS)):\n      tf.compat.v1.add_to_collection(\n          tf.compat.v1.GraphKeys.SAVERS,\n          tf.compat.v1.train.Saver(\n              sharded=True,\n              
max_to_keep=self._config.keep_checkpoint_max,\n              keep_checkpoint_every_n_hours=(\n                  self._config.keep_checkpoint_every_n_hours),\n              defer_build=True,\n              save_relative_paths=True))\n\n    if (self._config.cluster_spec and type(\n        self._train_distribution).__name__ in ('CollectiveAllReduceStrategy',\n                                               'CollectiveAllReduceStrategyV1',\n                                               'MultiWorkerMirroredStrategy')):\n      return self._train_with_estimator_spec_distributed(\n          estimator_spec, worker_hooks, saving_listeners)\n\n    chief_hooks = []\n    all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)\n    saver_hooks = [\n        h for h in all_hooks\n        if isinstance(h, tf.compat.v1.train.CheckpointSaverHook)\n    ]\n    if (self._config.save_checkpoints_secs or\n        self._config.save_checkpoints_steps):\n      if not saver_hooks:\n        chief_hooks = [\n            tf.compat.v1.train.CheckpointSaverHook(\n                self._model_dir,\n                save_secs=self._config.save_checkpoints_secs,\n                save_steps=self._config.save_checkpoints_steps,\n                scaffold=estimator_spec.scaffold,\n                save_graph_def=self._config.checkpoint_save_graph_def)\n        ]\n        saver_hooks = [chief_hooks[0]]\n    if saving_listeners:\n      if not saver_hooks:\n        raise ValueError(\n            'There should be a CheckpointSaverHook to use saving_listeners. '\n            'Please set one of the RunConfig.save_checkpoints_steps or '\n            'RunConfig.save_checkpoints_secs.')\n      else:\n        # It is expected to have one CheckpointSaverHook. 
If multiple, we pick\n        # up the first one to add listener.\n        for listener in saving_listeners:\n          # pylint: disable=protected-access\n          if listener not in saver_hooks[0]._listeners:\n            saver_hooks[0]._listeners.append(listener)\n          # pylint: enable=protected-access\n\n    # Add summary hooks to worker 0 if we are running with a master, to ensure\n    # that summaries are written at correct intervals even with long-running\n    # evaluations.\n    save_summary_steps = self._config.save_summary_steps\n    log_step_count_steps = self._config.log_step_count_steps\n\n    # Check existence of appropriate cluster spec fields, as well as master and\n    # worker nodes. As master also performs evaluation, summary writing must\n    # occur on a different node. The presence of a worker is also checked to\n    # prevent reassigning hooks for single-replica jobs with just a master node.\n    if (self._config.cluster_spec and self._config.cluster_spec.jobs and\n        (run_config.TaskType.WORKER in self._config.cluster_spec.jobs) and\n        (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):\n      # Update config values to prevent the default hooks from being created on\n      # the master or other workers.\n      save_summary_steps = 0\n      log_step_count_steps = None\n\n      if (self._config.task_type == run_config.TaskType.WORKER and\n          self._config.task_id == 0):\n        if (self._config.save_summary_steps and\n            self._config.save_summary_steps > 0):\n          worker_hooks.append(\n              tf.compat.v1.train.SummarySaverHook(\n                  save_steps=self._config.save_summary_steps,\n                  output_dir=self._config.model_dir,\n                  scaffold=estimator_spec.scaffold))\n\n        if (self._config.log_step_count_steps and\n            self._config.log_step_count_steps > 0):\n          worker_hooks.append(\n              tf.compat.v1.train.StepCounterHook(\n   
               every_n_steps=self._config.log_step_count_steps,\n                  output_dir=self._config.model_dir))\n\n    with training.MonitoredTrainingSession(\n        master=self._config.master,\n        is_chief=self._config.is_chief,\n        checkpoint_dir=self._model_dir,\n        scaffold=estimator_spec.scaffold,\n        hooks=worker_hooks,\n        chief_only_hooks=(tuple(chief_hooks) +\n                          tuple(estimator_spec.training_chief_hooks)),\n        save_checkpoint_secs=0,  # Saving is handled by a hook.\n        save_summaries_steps=save_summary_steps,\n        config=self._session_config,\n        max_wait_secs=self._config.session_creation_timeout_secs,\n        log_step_count_steps=log_step_count_steps,\n        save_graph_def=self._config.checkpoint_save_graph_def) as mon_sess:\n      loss = None\n      current_step = 0\n      while not mon_sess.should_stop():\n        current_step += 1\n        # just as keras(https://github.com/tensorflow/tensorflow/blob/v2.4.1/tensorflow/python/keras/engine/training.py#L1093),\n        # trace should be enabled for every step\n        with trace.Trace('train', step_num=current_step, _r=1):\n          _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])\n      if current_step == 0:\n        tf.compat.v1.logging.warn('Training with estimator made no steps. 
'\n                                  'Perhaps input is empty or misspecified.')\n    return loss\n\n  def _evaluate_build_graph(self, input_fn, hooks=None, checkpoint_path=None):\n    \"\"\"Builds the graph and related hooks to run evaluation.\"\"\"\n    tf.compat.v1.random.set_random_seed(self._config.tf_random_seed)\n    self._create_and_assert_global_step(tf.compat.v1.get_default_graph())\n\n    if self._eval_distribution:\n      (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict) = (\n          self._call_model_fn_eval_distributed(input_fn, self.config))\n    else:\n      (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict) = (\n          self._call_model_fn_eval(input_fn, self.config))\n\n    global_step_tensor = tf.compat.v1.train.get_global_step(\n        tf.compat.v1.get_default_graph())\n    # Call to warm_start has to be after model_fn is called.\n    self._maybe_warm_start(checkpoint_path)\n\n    if tf.compat.v1.GraphKeys.GLOBAL_STEP in eval_dict:\n      raise ValueError(\n          'Metric with name `global_step` is not allowed, because Estimator '\n          'already defines a default metric with the same name.')\n    eval_dict[tf.compat.v1.GraphKeys.GLOBAL_STEP] = global_step_tensor\n\n    all_hooks = list(input_hooks)\n    all_hooks.extend(hooks)\n    all_hooks.extend(list(evaluation_hooks or []))\n    # New local variables have been added, so update the estimator spec's\n    # local init op if it was defined.\n    if scaffold and scaffold.local_init_op:\n      # Ensure that eval step has been created before updating local init op.\n      evaluation._get_or_create_eval_step()  # pylint: disable=protected-access\n\n      scaffold = tf.compat.v1.train.Scaffold(\n          local_init_op=tf.group(\n              scaffold.local_init_op,\n              tf.compat.v1.train.Scaffold.default_local_init_op()),\n          copy_from_scaffold=scaffold)\n\n    return scaffold, update_op, eval_dict, all_hooks\n\n  def _call_model_fn_eval(self, 
input_fn, config):\n    \"\"\"Call model_fn for evaluation and handle return values.\"\"\"\n    features, labels, input_hooks = self._get_features_and_labels_from_input_fn(\n        input_fn, ModeKeys.EVAL)\n\n    estimator_spec = self._call_model_fn(features, labels, ModeKeys.EVAL,\n                                         config)\n    eval_metric_ops = _verify_and_create_loss_metric(\n        estimator_spec.eval_metric_ops, estimator_spec.loss)\n    update_op, eval_dict = _extract_metric_update_ops(eval_metric_ops)\n    return (estimator_spec.scaffold, estimator_spec.evaluation_hooks,\n            input_hooks, update_op, eval_dict)\n\n  def _call_model_fn_eval_distributed(self, input_fn, config):\n    \"\"\"Call model_fn in distribution mode and handle return values.\"\"\"\n\n    iterator, input_hooks = self._get_iterator_from_input_fn(\n        input_fn, ModeKeys.EVAL, self._eval_distribution)\n\n    is_tpu_strategy = (\n        self._eval_distribution.__class__.__name__.startswith('TPUStrategy'))\n\n    if is_tpu_strategy:\n      steps_per_run_variable = training.get_or_create_steps_per_run_variable()\n\n      def step_fn(ctx, inputs):\n        \"\"\"Runs one step of the eval computation and captures outputs.\"\"\"\n        if isinstance(inputs, tuple):\n          features, labels = inputs\n        else:\n          features = inputs\n          labels = None\n        estimator_spec = self._eval_distribution.extended.call_for_each_replica(\n            self._call_model_fn, args=(features, labels, ModeKeys.EVAL, config))\n        eval_metric_ops = _verify_and_create_loss_metric(\n            estimator_spec.eval_metric_ops, estimator_spec.loss,\n            self._eval_distribution)\n        update_op, eval_dict = _extract_metric_update_ops(\n            eval_metric_ops, self._eval_distribution)\n        ctx.set_non_tensor_output(name='estimator_spec', output=estimator_spec)\n        ctx.set_non_tensor_output(name='eval_dict', output=eval_dict)\n        return 
update_op\n\n      # TODO(priyag): Fix eval step hook to account for steps_per_run.\n      ctx = self._eval_distribution.extended.experimental_run_steps_on_iterator(\n          step_fn, iterator, iterations=steps_per_run_variable)\n      update_op = ctx.run_op\n      eval_dict = ctx.non_tensor_outputs['eval_dict']\n      grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']\n    else:\n      features, labels = estimator_util.parse_iterator_result(\n          iterator.get_next())\n      grouped_estimator_spec = (\n          self._eval_distribution.extended.call_for_each_replica(\n              self._call_model_fn,\n              args=(features, labels, ModeKeys.EVAL, config)))\n      eval_metric_ops = _verify_and_create_loss_metric(\n          grouped_estimator_spec.eval_metric_ops, grouped_estimator_spec.loss,\n          self._eval_distribution)\n      update_op, eval_dict = _extract_metric_update_ops(eval_metric_ops,\n                                                        self._eval_distribution)\n\n    scaffold = _combine_distributed_scaffold(grouped_estimator_spec.scaffold,\n                                             self._eval_distribution)\n\n    def get_hooks_from_the_first_device(per_device_hooks):\n      return [\n          self._eval_distribution.experimental_local_results(per_device_hook)[0]\n          for per_device_hook in per_device_hooks\n      ]\n\n    evaluation_hooks = get_hooks_from_the_first_device(\n        grouped_estimator_spec.evaluation_hooks)\n\n    return (scaffold, evaluation_hooks, input_hooks, update_op, eval_dict)\n\n  def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict,\n                    all_hooks, output_dir):\n    \"\"\"Run evaluation.\"\"\"\n    eval_results = evaluation._evaluate_once(  # pylint: disable=protected-access\n        checkpoint_path=checkpoint_path,\n        master=self._config.evaluation_master,\n        scaffold=scaffold,\n        eval_ops=update_op,\n        
final_ops=eval_dict,\n        hooks=all_hooks,\n        config=self._session_config)\n\n    current_global_step = eval_results[tf.compat.v1.GraphKeys.GLOBAL_STEP]\n\n    _write_dict_to_summary(\n        output_dir=output_dir,\n        dictionary=eval_results,\n        current_global_step=current_global_step)\n\n    if checkpoint_path:\n      _write_checkpoint_path_to_summary(\n          output_dir=output_dir,\n          checkpoint_path=checkpoint_path,\n          current_global_step=current_global_step)\n\n    return eval_results\n\n  def _maybe_warm_start(self, checkpoint_path):\n    if not checkpoint_path and self._warm_start_settings:\n      tf.compat.v1.logging.info('Warm-starting with WarmStartSettings: %s' %\n                                (self._warm_start_settings,))\n      tf.compat.v1.train.warm_start(*self._warm_start_settings)\n\n  @deprecation.deprecated(\n      None, 'This function has been renamed, use `export_saved_model` instead.')\n  def export_savedmodel(self,\n                        export_dir_base,\n                        serving_input_receiver_fn,\n                        assets_extra=None,\n                        as_text=False,\n                        checkpoint_path=None,\n                        strip_default_attrs=False):\n    # pylint: disable=line-too-long\n    \"\"\"Exports inference graph as a `SavedModel` into the given dir.\n\n    For a detailed guide, see\n    [SavedModel from\n    Estimators.](https://www.tensorflow.org/guide/estimator#savedmodels_from_estimators).\n\n    This method builds a new graph by first calling the\n    `serving_input_receiver_fn` to obtain feature `Tensor`s, and then calling\n    this `Estimator`'s `model_fn` to generate the model graph based on those\n    features. It restores the given checkpoint (or, lacking that, the most\n    recent checkpoint) into this graph in a fresh session.  
Finally it creates\n    a timestamped export directory below the given `export_dir_base`, and writes\n    a `SavedModel` into it containing a single `tf.MetaGraphDef` saved from this\n    session.\n\n    The exported `MetaGraphDef` will provide one `SignatureDef` for each\n    element of the `export_outputs` dict returned from the `model_fn`, named\n    using the same keys.  One of these keys is always\n    `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`,\n    indicating which signature will be served when a serving request does not\n    specify one. For each signature, the outputs are provided by the\n    corresponding `tf.estimator.export.ExportOutput`s, and the inputs are always\n    the input receivers provided by the `serving_input_receiver_fn`.\n\n    Extra assets may be written into the `SavedModel` via the `assets_extra`\n    argument.  This should be a dict, where each key gives a destination path\n    (including the filename) relative to the assets.extra directory.  The\n    corresponding value gives the full path of the source file to be copied.\n    For example, the simple case of copying a single file without renaming it\n    is specified as `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.\n\n    Args:\n      export_dir_base: A string containing a directory in which to create\n        timestamped subdirectories containing exported `SavedModel`s.\n      serving_input_receiver_fn: A function that takes no argument and returns a\n        `tf.estimator.export.ServingInputReceiver` or\n        `tf.estimator.export.TensorServingInputReceiver`.\n      assets_extra: A dict specifying how to populate the assets.extra directory\n        within the exported `SavedModel`, or `None` if no extra assets are\n        needed.\n      as_text: whether to write the `SavedModel` proto in text format.\n      checkpoint_path: The checkpoint path to export.  
If `None` (the default),\n        the most recent checkpoint found within the model directory is chosen.\n      strip_default_attrs: Boolean. If `True`, default-valued attributes will be\n        removed from the `NodeDef`s. For a detailed guide, see [Stripping\n        Default-Valued Attributes](\n        https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).\n\n    Returns:\n      The path to the exported directory as a bytes object.\n\n    Raises:\n      ValueError: if no `serving_input_receiver_fn` is provided, no\n      `export_outputs` are provided, or no checkpoint can be found.\n    \"\"\"\n    # pylint: enable=line-too-long\n    if not serving_input_receiver_fn:\n      raise ValueError('An input_receiver_fn must be defined.')\n\n    return self._export_all_saved_models(\n        export_dir_base, {ModeKeys.PREDICT: serving_input_receiver_fn},\n        assets_extra=assets_extra,\n        as_text=as_text,\n        checkpoint_path=checkpoint_path,\n        strip_default_attrs=strip_default_attrs)\n\n\n@estimator_export('estimator.Estimator', v1=[])  # pylint: disable=missing-docstring\nclass EstimatorV2(Estimator):\n  __doc__ = Estimator.__doc__\n\n  export_savedmodel = deprecation.hide_attribute_from_api(\n      '`Estimator.export_savedmodel` has been deprecated. 
Please use '\n      '`export_saved_model` instead.')\n\n  def _assert_members_are_not_overridden(self):\n    \"\"\"Asserts members of `Estimator` are not overridden.\"\"\"\n    _assert_members_are_not_overridden(EstimatorV2, self)\n\n\ndef _get_loss_reduce_op_for_reporting():\n  graph = tf.compat.v1.get_default_graph()\n  if getattr(graph, '_is_loss_scaled_by_optimizer', False):  # pylint: disable=protected-access\n    return tf.compat.v1.distribute.get_loss_reduction()\n  return tf.distribute.ReduceOp.SUM\n\n\ndef _assert_members_are_not_overridden(cls, obj):\n  \"\"\"Assert Estimator methods are not overwritten.\"\"\"\n  # TPUEstimator is special cased (owned by TF).\n  if obj.__class__.__name__ == 'TPUEstimator':\n    return\n\n  allowed_overrides = set([\n      'model_fn', '_create_and_assert_global_step', '_export_all_saved_models',\n      '_tf_api_names', '_tf_api_names_v1', '_estimator_api_names',\n      '_estimator_api_names_v1', '_estimator_api_constants',\n      '_estimator_api_constants_v1', 'latest_checkpoint'\n  ])\n\n  estimator_members = set([m for m in dir(cls) if not m.startswith('__')])\n  subclass_members = set(obj.__class__.__dict__.keys())\n  common_members = estimator_members & subclass_members - allowed_overrides\n  overridden_members = [\n      m for m in common_members if getattr(cls, m) != getattr(obj.__class__, m)\n  ]\n  if overridden_members:\n    raise ValueError(\n        'Subclasses of Estimator cannot override members of Estimator. 
'\n        '{} does override {}'.format(obj.__class__, overridden_members))\n\n\ndef _verify_and_create_loss_metric(eval_metric_ops, loss, distribution=None):\n  \"\"\"Creates a metric for loss and throws an error if one already exists.\"\"\"\n  if model_fn_lib.LOSS_METRIC_KEY in eval_metric_ops:\n    raise ValueError(\n        'Metric with name \"%s\" is not allowed, because Estimator ' %\n        (model_fn_lib.LOSS_METRIC_KEY) +\n        'already defines a default metric with the same name.')\n\n  if distribution is None:\n    loss_metric = tf.compat.v1.metrics.mean(loss)\n  else:\n    loss_metric = distribution.extended.call_for_each_replica(\n        tf.compat.v1.metrics.mean, args=(loss,))\n  eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric\n  return eval_metric_ops\n\n\ndef maybe_overwrite_model_dir_and_session_config(config, model_dir):\n  \"\"\"Overwrite estimator config by `model_dir` and `session_config` if needed.\n\n  Args:\n    config: Original estimator config.\n    model_dir: Estimator model checkpoint directory.\n\n  Returns:\n    Overwritten estimator config.\n\n  Raises:\n    ValueError: Model directory inconsistent between `model_dir` and `config`.\n  \"\"\"\n\n  if config is None:\n    config = run_config.RunConfig()\n    tf.compat.v1.logging.info('Using default config.')\n  if not isinstance(config, run_config.RunConfig):\n    raise ValueError(\n        'config must be an instance of `RunConfig`, but provided %s.' % config)\n\n  if config.session_config is None:\n    session_config = run_config.get_default_session_config()\n    config = run_config.RunConfig.replace(config, session_config=session_config)\n\n  model_dir = run_config.path_to_str(model_dir)\n  if model_dir is not None:\n    if (getattr(config, 'model_dir', None) is not None and\n        config.model_dir != model_dir):\n      raise ValueError(\n          '`model_dir` are set both in constructor and `RunConfig`, but with '\n          \"different values. 
In constructor: '{}', in `RunConfig`: \"\n          \"'{}' \".format(model_dir, config.model_dir))\n  if model_dir:\n    config = run_config.RunConfig.replace(config, model_dir=model_dir)\n  elif getattr(config, 'model_dir', None) is None:\n    model_dir = tempfile.mkdtemp()\n    tf.compat.v1.logging.warn('Using temporary folder as model directory: %s',\n                              model_dir)\n    config = run_config.RunConfig.replace(config, model_dir=model_dir)\n\n  return config\n\n\ndef create_per_replica_ready_for_local_init_op(scaffold):\n  \"\"\"Create a `tf.train.Scaffold.ready_for_local_init_op` inside a replica.\"\"\"\n  if scaffold.ready_for_local_init_op:\n    return scaffold.ready_for_local_init_op\n\n  def default_ready_for_local_init_op():\n    return tf.compat.v1.report_uninitialized_variables(\n        tf.compat.v1.global_variables())\n\n  return tf.compat.v1.train.Scaffold.get_or_default(\n      'ready_for_local_init_op', tf.compat.v1.GraphKeys.READY_FOR_LOCAL_INIT_OP,\n      default_ready_for_local_init_op)\n\n\ndef _combine_distributed_scaffold(grouped_scaffold, distribution):\n  \"\"\"Combines scaffold(s) returned from `call_for_each_replica`.\"\"\"\n\n  # TODO(anjalisridhar): Figure out how to resolve the following scaffold\n  # parameters: init_feed_dict, init_fn.\n  scaffold_list = distribution.experimental_local_results(grouped_scaffold)\n  init_feed_dict = [\n      s.init_feed_dict for s in scaffold_list if s.init_feed_dict is not None\n  ]\n  if init_feed_dict:\n    init_feed_dict = distribution.group(init_feed_dict)\n  else:\n    init_feed_dict = None\n\n  init_fn = [\n      s._user_init_fn for s in scaffold_list if s._user_init_fn is not None  # pylint: disable=protected-access\n  ]\n  if init_fn:\n    init_fn = init_fn[0]\n  else:\n    init_fn = None\n\n  init_op = [s.init_op for s in scaffold_list if s.init_op is not None]\n  if init_op:\n    init_op = distribution.group(init_op)\n  else:\n    init_op = None\n\n  def 
_unwrap_and_concat(value):\n    value = tf.nest.flatten(distribution.experimental_local_results(value))\n    if len(value) != 1:\n      return tf.concat(value, 0)\n    return value[0]\n\n  ready_op = distribution.extended.call_for_each_replica(\n      lambda scaffold: scaffold.ready_op, args=(grouped_scaffold,))\n  if ready_op is not None:\n    ready_op = _unwrap_and_concat(ready_op)\n\n  ready_for_local_init_op = distribution.extended.call_for_each_replica(\n      create_per_replica_ready_for_local_init_op, args=(grouped_scaffold,))\n  if ready_for_local_init_op is not None:\n    ready_for_local_init_op = _unwrap_and_concat(ready_for_local_init_op)\n  else:\n    ready_for_local_init_op = None\n\n  local_init_op = [\n      s.local_init_op for s in scaffold_list if s.local_init_op is not None\n  ]\n  if local_init_op:\n    local_init_op = distribution.group(local_init_op)\n  else:\n    local_init_op = None\n\n  summary_op = [s.summary_op for s in scaffold_list if s.summary_op is not None]\n  if summary_op:\n    summary_op = distribution.group(summary_op)\n  else:\n    summary_op = None\n\n  savers = [s.saver for s in scaffold_list if s.saver is not None]\n  if savers:\n    saver = savers[0]\n  else:\n    saver = None\n\n  scaffold = tf.compat.v1.train.Scaffold(\n      init_op=init_op,\n      ready_op=ready_op,\n      ready_for_local_init_op=ready_for_local_init_op,\n      local_init_op=local_init_op,\n      summary_op=summary_op,\n      saver=saver,\n      init_feed_dict=init_feed_dict,\n      init_fn=init_fn)\n  return scaffold\n\n\ndef _check_checkpoint_available(model_dir):\n  latest_path = tf.train.latest_checkpoint(model_dir)\n  if not latest_path:\n    raise ValueError(\n        'Could not find trained model in model_dir: {}.'.format(model_dir))\n\n\ndef _check_hooks_type(hooks):\n  \"\"\"Returns hooks if all are `SessionRunHook`, raises TypeError otherwise.\"\"\"\n  hooks = list(hooks or [])\n  for h in hooks:\n    if not isinstance(h, 
tf.compat.v1.train.SessionRunHook):\n      raise TypeError('Hooks must be a SessionRunHook, given: {}'.format(h))\n  return hooks\n\n\ndef _check_listeners_type(saving_listeners):\n  \"\"\"Check listeners type.\"\"\"\n  listeners = list(saving_listeners or [])\n  for l in listeners:\n    if not isinstance(l, tf.compat.v1.train.CheckpointSaverListener):\n      raise TypeError(\n          'saving_listeners must be a list of CheckpointSaverListener, '\n          'given: {}'.format(l))\n  return listeners\n\n\ndef _get_replica_device_setter(config):\n  \"\"\"Creates a replica device setter if required as a default `device_fn`.\n\n  `Estimator` uses `tf.train.ReplicaDeviceSetter` as a default device placer. It\n  sets the distributed related arguments such as number of `ps_replicas` based\n  on given `config`.\n\n  Args:\n    config: A `tf.estimator.RunConfig` instance.\n\n  Returns:\n    A replica device setter, or `None`.\n  \"\"\"\n  if config.task_type:\n    worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id)\n  else:\n    worker_device = '/job:worker'\n\n  if config.num_ps_replicas > 0:\n    return tf.compat.v1.train.replica_device_setter(\n        ps_tasks=config.num_ps_replicas,\n        worker_device=worker_device,\n        merge_devices=True,\n        ps_ops=list(device_setter.STANDARD_PS_OPS),\n        cluster=config.cluster_spec)\n  else:\n    return None\n\n\ndef _verify_model_fn_args(model_fn, params):\n  \"\"\"Verifies `model_fn` arguments.\"\"\"\n  args = set(function_utils.fn_args(model_fn))\n  if 'features' not in args:\n    raise ValueError('model_fn (%s) must include features argument.' % model_fn)\n  if params is not None and 'params' not in args:\n    raise ValueError('model_fn (%s) does not include params argument, '\n                     'but params (%s) is passed to Estimator.' 
%\n                     (model_fn, params))\n  if params is None and 'params' in args:\n    tf.compat.v1.logging.warn(\n        'Estimator\\'s model_fn (%s) includes params '\n        'argument, but params are not passed to Estimator.', model_fn)\n  non_valid_args = list(args - _VALID_MODEL_FN_ARGS)\n  if non_valid_args:\n    raise ValueError('model_fn (%s) has following not expected args: %s' %\n                     (model_fn, non_valid_args))\n\n\ndef _load_global_step_from_checkpoint_dir(checkpoint_dir):\n  try:\n    checkpoint_reader = tf.compat.v1.train.NewCheckpointReader(\n        tf.train.latest_checkpoint(checkpoint_dir))\n    return checkpoint_reader.get_tensor(tf.compat.v1.GraphKeys.GLOBAL_STEP)\n  except:  # pylint: disable=bare-except\n    return 0\n\n\ndef _extract_metric_update_ops(eval_dict, distribution=None):\n  \"\"\"Separate update operations from metric value operations.\"\"\"\n  update_ops = []\n  value_ops = {}\n  # Sort metrics lexicographically so graph is identical every time.\n  for name, value in sorted(six.iteritems(eval_dict)):\n    value_ops[name] = value[0]\n    update_ops.append(\n        distribution.group(value[1]) if distribution else value[1])\n\n  update_op = tf.group(*update_ops) if update_ops else None\n  return update_op, value_ops\n\n\ndef _dict_to_str(dictionary):\n  \"\"\"Get a `str` representation of a `dict`.\n\n  Args:\n    dictionary: The `dict` to be represented as `str`.\n\n  Returns:\n    A `str` representing the `dictionary`.\n  \"\"\"\n  return ', '.join('%s = %s' % (k, v)\n                   for k, v in sorted(six.iteritems(dictionary))\n                   if not isinstance(v, six.binary_type))\n\n\ndef _write_dict_to_summary(output_dir, dictionary, current_global_step):\n  \"\"\"Writes a `dict` into summary file in given output directory.\n\n  Args:\n    output_dir: `str`, directory to write the summary file in.\n    dictionary: the `dict` to be written to summary file.\n    current_global_step: `int`, the 
current global step.\n  \"\"\"\n  tf.compat.v1.logging.info('Saving dict for global step %d: %s',\n                            current_global_step, _dict_to_str(dictionary))\n  summary_writer = tf.compat.v1.summary.FileWriterCache.get(output_dir)\n  summary_proto = summary_pb2.Summary()\n  for key in dictionary:\n    if dictionary[key] is None:\n      continue\n    if key == 'global_step':\n      continue\n    if (isinstance(dictionary[key], np.float32) or\n        isinstance(dictionary[key], float)):\n      summary_proto.value.add(tag=key, simple_value=float(dictionary[key]))\n    elif (isinstance(dictionary[key], np.int64) or\n          isinstance(dictionary[key], np.int32) or\n          isinstance(dictionary[key], int)):\n      summary_proto.value.add(tag=key, simple_value=int(dictionary[key]))\n    elif isinstance(dictionary[key], six.binary_type):\n      try:\n        summ = summary_pb2.Summary.FromString(dictionary[key])\n        for i, _ in enumerate(summ.value):\n          summ.value[i].tag = '%s/%d' % (key, i)\n        summary_proto.value.extend(summ.value)\n      except message.DecodeError:\n        tf.compat.v1.logging.warn(\n            'Skipping summary for %s, cannot parse string to Summary.', key)\n        continue\n    elif isinstance(dictionary[key], np.ndarray):\n      value = summary_proto.value.add()\n      value.tag = key\n      value.node_name = key\n      tensor_proto = tf.make_tensor_proto(dictionary[key])\n      value.tensor.CopyFrom(tensor_proto)\n      # pylint: disable=line-too-long\n      tf.compat.v1.logging.info(\n          'Summary for np.ndarray is not visible in Tensorboard by default. 
'\n          'Consider using a Tensorboard plugin for visualization (see '\n          'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md'\n          ' for more information).')\n      # pylint: enable=line-too-long\n    else:\n      tf.compat.v1.logging.warn(\n          'Skipping summary for %s, must be a float, np.float32, np.int64, '\n          'np.int32 or int or np.ndarray or a serialized string of Summary.',\n          key)\n  summary_writer.add_summary(summary_proto, current_global_step)\n  summary_writer.flush()\n\n\ndef _write_checkpoint_path_to_summary(output_dir, checkpoint_path,\n                                      current_global_step):\n  \"\"\"Writes `checkpoint_path` into summary file in the given output directory.\n\n  Args:\n    output_dir: `str`, directory to write the summary file in.\n    checkpoint_path: `str`, checkpoint file path to be written to summary file.\n    current_global_step: `int`, the current global step.\n  \"\"\"\n\n  checkpoint_path_tag = 'checkpoint_path'\n\n  tf.compat.v1.logging.info('Saving \\'%s\\' summary for global step %d: %s',\n                            checkpoint_path_tag, current_global_step,\n                            checkpoint_path)\n  summary_proto = summary_pb2.Summary()\n  summary_proto.value.add(\n      tag=checkpoint_path_tag,\n      tensor=tf.make_tensor_proto(checkpoint_path, dtype=tf.dtypes.string))\n  summary_writer = tf.compat.v1.summary.FileWriterCache.get(output_dir)\n  summary_writer.add_summary(summary_proto, current_global_step)\n  summary_writer.flush()\n\n\ndef _has_dataset_or_queue_runner(maybe_tensor):\n  \"\"\"Returns `True` if `Dataset` or `QueueRunner` has been used.\"\"\"\n  # Check TF dataset first. 
Here, we use a simple algorithm to check the top\n  # level Tensors only, which should be sufficient for most users.\n  tensors = [\n      x for x in tf.nest.flatten(maybe_tensor) if isinstance(x, tf.Tensor)\n  ]\n  if any([t.op.type == 'IteratorGetNext' for t in tensors]):\n    return True\n\n  # Now, check queue.\n  return tf.compat.v1.get_default_graph().get_collection(\n      tf.compat.v1.GraphKeys.QUEUE_RUNNERS)\n\n\nVocabInfo = tf.compat.v1.train.VocabInfo  # pylint: disable=invalid-name\nestimator_export('estimator.VocabInfo')(VocabInfo)\n\n\n@estimator_export('estimator.WarmStartSettings')\nclass WarmStartSettings(\n    collections.namedtuple('WarmStartSettings', [\n        'ckpt_to_initialize_from',\n        'vars_to_warm_start',\n        'var_name_to_vocab_info',\n        'var_name_to_prev_var_name',\n    ])):\n  \"\"\"Settings for warm-starting in `tf.estimator.Estimators`.\n\n  Example Use with canned `tf.estimator.DNNEstimator`:\n\n  ```\n  emb_vocab_file = tf.feature_column.embedding_column(\n      tf.feature_column.categorical_column_with_vocabulary_file(\n          \"sc_vocab_file\", \"new_vocab.txt\", vocab_size=100),\n      dimension=8)\n  emb_vocab_list = tf.feature_column.embedding_column(\n      tf.feature_column.categorical_column_with_vocabulary_list(\n          \"sc_vocab_list\", vocabulary_list=[\"a\", \"b\"]),\n      dimension=8)\n  estimator = tf.estimator.DNNClassifier(\n    hidden_units=[128, 64], feature_columns=[emb_vocab_file, emb_vocab_list],\n    warm_start_from=ws)\n  ```\n\n  where `ws` could be defined as:\n\n  Warm-start all weights in the model (input layer and hidden weights).\n  Either the directory or a specific checkpoint can be provided (in the case\n  of the former, the latest checkpoint will be used):\n\n  ```\n  ws = WarmStartSettings(ckpt_to_initialize_from=\"/tmp\")\n  ws = WarmStartSettings(ckpt_to_initialize_from=\"/tmp/model-1000\")\n  ```\n\n  Warm-start only the embeddings (input layer):\n\n  ```\n  ws = 
WarmStartSettings(ckpt_to_initialize_from=\"/tmp\",\n                         vars_to_warm_start=\".*input_layer.*\")\n  ```\n\n  Warm-start all weights but the embedding parameters corresponding to\n  `sc_vocab_file` have a different vocab from the one used in the current\n  model:\n\n  ```\n  vocab_info = tf.estimator.VocabInfo(\n      new_vocab=sc_vocab_file.vocabulary_file,\n      new_vocab_size=sc_vocab_file.vocabulary_size,\n      num_oov_buckets=sc_vocab_file.num_oov_buckets,\n      old_vocab=\"old_vocab.txt\"\n  )\n  ws = WarmStartSettings(\n      ckpt_to_initialize_from=\"/tmp\",\n      var_name_to_vocab_info={\n          \"input_layer/sc_vocab_file_embedding/embedding_weights\": vocab_info\n      })\n  ```\n\n  Warm-start only `sc_vocab_file` embeddings (and no other variables), which\n  have a different vocab from the one used in the current model:\n\n  ```\n  vocab_info = tf.estimator.VocabInfo(\n      new_vocab=sc_vocab_file.vocabulary_file,\n      new_vocab_size=sc_vocab_file.vocabulary_size,\n      num_oov_buckets=sc_vocab_file.num_oov_buckets,\n      old_vocab=\"old_vocab.txt\"\n  )\n  ws = WarmStartSettings(\n      ckpt_to_initialize_from=\"/tmp\",\n      vars_to_warm_start=None,\n      var_name_to_vocab_info={\n          \"input_layer/sc_vocab_file_embedding/embedding_weights\": vocab_info\n      })\n  ```\n\n  Warm-start all weights but the parameters corresponding to `sc_vocab_file`\n  have a different vocab from the one used in current checkpoint, and only\n  100 of those entries were used:\n\n  ```\n  vocab_info = tf.estimator.VocabInfo(\n      new_vocab=sc_vocab_file.vocabulary_file,\n      new_vocab_size=sc_vocab_file.vocabulary_size,\n      num_oov_buckets=sc_vocab_file.num_oov_buckets,\n      old_vocab=\"old_vocab.txt\",\n      old_vocab_size=100\n  )\n  ws = WarmStartSettings(\n      ckpt_to_initialize_from=\"/tmp\",\n      var_name_to_vocab_info={\n          \"input_layer/sc_vocab_file_embedding/embedding_weights\": vocab_info\n      
})\n  ```\n\n  Warm-start all weights but the parameters corresponding to `sc_vocab_file`\n  have a different vocab from the one used in current checkpoint and the\n  parameters corresponding to `sc_vocab_list` have a different name from the\n  current checkpoint:\n\n  ```\n  vocab_info = tf.estimator.VocabInfo(\n      new_vocab=sc_vocab_file.vocabulary_file,\n      new_vocab_size=sc_vocab_file.vocabulary_size,\n      num_oov_buckets=sc_vocab_file.num_oov_buckets,\n      old_vocab=\"old_vocab.txt\",\n      old_vocab_size=100\n  )\n  ws = WarmStartSettings(\n      ckpt_to_initialize_from=\"/tmp\",\n      var_name_to_vocab_info={\n          \"input_layer/sc_vocab_file_embedding/embedding_weights\": vocab_info\n      },\n      var_name_to_prev_var_name={\n          \"input_layer/sc_vocab_list_embedding/embedding_weights\":\n              \"old_tensor_name\"\n      })\n  ```\n\n  Warm-start all TRAINABLE variables:\n\n  ```\n  ws = WarmStartSettings(ckpt_to_initialize_from=\"/tmp\",\n                         vars_to_warm_start=\".*\")\n  ```\n\n  Warm-start all variables (including non-TRAINABLE):\n\n  ```\n  ws = WarmStartSettings(ckpt_to_initialize_from=\"/tmp\",\n                         vars_to_warm_start=[\".*\"])\n  ```\n\n  Warm-start non-TRAINABLE variables \"v1\", \"v1/Momentum\", and \"v2\" but not\n  \"v2/momentum\":\n\n  ```\n  ws = WarmStartSettings(ckpt_to_initialize_from=\"/tmp\",\n                         vars_to_warm_start=[\"v1\", \"v2[^/]\"])\n  ```\n\n  Attributes:\n    ckpt_to_initialize_from: [Required] A string specifying the directory with\n      checkpoint file(s) or path to checkpoint from which to warm-start the\n      model parameters.\n    vars_to_warm_start: [Optional] One of the following:\n\n      * A regular expression (string) that captures which variables to\n        warm-start (see tf.compat.v1.get_collection).  
This expression will only\n        consider variables in the TRAINABLE_VARIABLES collection -- if you need\n        to warm-start non_TRAINABLE vars (such as optimizer accumulators or\n        batch norm statistics), please use the below option.\n      * A list of strings, each a regex scope provided to\n        tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see\n        tf.compat.v1.get_collection).  For backwards compatibility reasons, this\n        is separate from the single-string argument type.\n      * A list of Variables to warm-start.  If you do not have access to the\n        `Variable` objects at the call site, please use the above option.\n      * `None`, in which case only TRAINABLE variables specified in\n        `var_name_to_vocab_info` will be warm-started.\n\n      Defaults to `'.*'`, which warm-starts all variables in the\n      TRAINABLE_VARIABLES collection. Note that this excludes variables such as\n      accumulators and moving statistics from batch norm.\n    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to\n      `tf.estimator.VocabInfo`. The variable names should be \"full\" variables,\n      not the names of the partitions.  If not explicitly provided, the variable\n      is assumed to have no (changes to) vocabulary.\n    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to\n      name of the previously-trained variable in `ckpt_to_initialize_from`. If\n      not explicitly provided, the name of the variable is assumed to be same\n      between previous checkpoint and current model.  
Note that this has no\n      effect on the set of variables that is warm-started, and only controls\n      name mapping (use `vars_to_warm_start` for controlling what variables to\n      warm-start).\n  \"\"\"\n\n  def __new__(cls,\n              ckpt_to_initialize_from,\n              vars_to_warm_start='.*',\n              var_name_to_vocab_info=None,\n              var_name_to_prev_var_name=None):\n    if not ckpt_to_initialize_from:\n      raise ValueError(\n          '`ckpt_to_initialize_from` MUST be set in WarmStartSettings')\n    return super(WarmStartSettings, cls).__new__(\n        cls,\n        ckpt_to_initialize_from,\n        vars_to_warm_start,\n        var_name_to_vocab_info or {},\n        var_name_to_prev_var_name or {},\n    )\n\n\ndef _get_default_warm_start_settings(warm_start_from):\n  \"\"\"Returns default `tf.estimator.WarmStartSettings`.\n\n  Args:\n    warm_start_from: Either a string representing the filepath of a checkpoint\n      or `SavedModel` to initialize from, or an instance of\n      `tf.estimator.WarmStartSettings`.\n\n  Returns:\n    Either None or an instance of `WarmStartSettings`.\n\n  Raises:\n    ValueError: If `warm_start_from` is not `None` but is neither a string nor\n    an instance of `WarmStartSettings`.\n  \"\"\"\n  if warm_start_from is None:\n    return None\n  if isinstance(warm_start_from, (six.string_types, six.binary_type)):\n    # Infer that this is a SavedModel if export_path +\n    # 'variables/variables.index' exists, and if so, construct the\n    # WarmStartSettings pointing to the variables path\n    # (export_path + 'variables/variables').\n    if tf.compat.v1.gfile.Exists(\n        os.path.join(\n            path_helpers.get_variables_dir(warm_start_from),\n            tf.compat.as_text('variables.index'))):\n      tf.compat.v1.logging.info('Warm-starting from a SavedModel')\n      return WarmStartSettings(\n          ckpt_to_initialize_from=path_helpers.get_variables_path(\n              
warm_start_from))\n    return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)\n  elif isinstance(warm_start_from, WarmStartSettings):\n    return warm_start_from\n  else:\n    raise ValueError('warm_start_from must be a string or a WarmStartSettings, '\n                     'instead got {}'.format(type(warm_start_from)))\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/estimator_export.py",
    "content": "# Copyright 2023 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Utilities for exporting TensorFlow Estimator symbols to the API.\n\nExporting a function or a class:\n\nTo export a function or a class use the estimator_export decorator. For e.g.:\n```python\n@estimator_export('foo', 'bar.foo')\ndef foo(...):\n  ...\n```\n\nIf a function is assigned to a variable, you can export it by calling\nestimator_export explicitly. 
For e.g.:\n```python\nfoo = get_foo(...)\nestimator_export('foo', 'bar.foo')(foo)\n```\n\n\nExporting a constant\n```python\nfoo = 1\nestimator_export('consts.foo').export_constant(__name__, 'foo')\n```\n\"\"\"\nfrom collections.abc import Sequence\nfrom typing import Optional, TypeVar\n\nfrom tensorflow.python.util import deprecation\nfrom tensorflow.python.util import tf_export\n\nT = TypeVar('T')\n\nESTIMATOR_API_NAME = 'estimator'\n\n\n# pylint: disable=protected-access\nif ESTIMATOR_API_NAME not in tf_export.API_ATTRS:\n  tf_export.API_ATTRS[ESTIMATOR_API_NAME] = tf_export._Attributes(\n      '_estimator_api_names', '_estimator_api_constants'\n  )\nif ESTIMATOR_API_NAME not in tf_export.API_ATTRS_V1:\n  tf_export.API_ATTRS_V1[ESTIMATOR_API_NAME] = tf_export._Attributes(\n      '_estimator_api_names_v1', '_estimator_api_constants_v1'\n  )\n# pylint: enable=protected-access\n\n\nclass estimator_export(tf_export.api_export):  # pylint: disable=invalid-name\n  \"\"\"Provides ways to export symbols to the TensorFlow Estimator API.\"\"\"\n\n  def __init__(self, *args: str, v1: Optional[Sequence[str]] = None):\n    \"\"\"Export under the names *args (first one is considered canonical).\n\n    All symbols exported by this decorator are exported under the `estimator`\n    API name.\n\n    Args:\n      *args: API names in dot delimited format.\n      v1: Names for the TensorFlow V1 API. If not set, we will use V2 API names\n        both for TensorFlow V1 and V2 APIs.\n    \"\"\"\n    super().__init__(*args, api_name=ESTIMATOR_API_NAME, v1=v1)\n\n  def __call__(self, func: T) -> T:\n    \"\"\"Calls this decorator.\n\n    Args:\n      func: decorated symbol (function or class).\n\n    Returns:\n      The input function with _tf_api_names attribute set and marked as\n      deprecated.\n    \"\"\"\n    func = deprecation.deprecated(None, 'Use tf_keras instead.')(func)\n    return super().__call__(func)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/estimator_export_test.py",
    "content": "# Copyright 2023 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"estimator_export tests.\"\"\"\n\nimport sys\nimport tensorflow as tf\n\nfrom tensorflow.python.platform import tf_logging as logging\nfrom tensorflow.python.util import tf_export\n# pylint: disable=g-deprecated-tf-checker\nfrom tensorflow_estimator.python.estimator import estimator_export\n\n\nclass TestClass(object):\n  pass\n\n\nclass ValidateExportTest(tf.test.TestCase):\n  \"\"\"Tests for estimator_export class.\"\"\"\n\n  def setUp(self):\n    super().setUp()\n    self._modules = []\n\n  def tearDown(self):\n    super().tearDown()\n    for name in self._modules:\n      del sys.modules[name]\n    self._modules = []\n    if hasattr(TestClass, '_estimator_api_names'):\n      del TestClass._estimator_api_names\n    if hasattr(TestClass, '_estimator_api_names_v1'):\n      del TestClass._estimator_api_names_v1\n\n  @tf.compat.v1.test.mock.patch.object(\n      logging, 'warning', autospec=True\n  )\n  def testExportDeprecated(self, mock_warning):\n    export_decorator = estimator_export.estimator_export('estimator.TestClass')\n    export_decorator(TestClass)\n\n    # Deprecation should trigger a runtime warning\n    TestClass()\n    self.assertEqual(1, mock_warning.call_count)\n    # Deprecation should only warn once, upon first call\n    TestClass()\n 
   self.assertEqual(1, mock_warning.call_count)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/estimator_lib.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Estimator: High level tools for working with models.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# pylint: disable=unused-import,line-too-long,wildcard-import\nfrom tensorflow_estimator.python.estimator.canned.baseline import BaselineClassifier\nfrom tensorflow_estimator.python.estimator.canned.baseline import BaselineEstimator\nfrom tensorflow_estimator.python.estimator.canned.baseline import BaselineRegressor\nfrom tensorflow_estimator.python.estimator.canned.dnn import dnn_logit_fn_builder\nfrom tensorflow_estimator.python.estimator.canned.dnn import DNNClassifier\nfrom tensorflow_estimator.python.estimator.canned.dnn import DNNEstimator\nfrom tensorflow_estimator.python.estimator.canned.dnn import DNNRegressor\nfrom tensorflow_estimator.python.estimator.canned.dnn_linear_combined import DNNLinearCombinedClassifier\nfrom tensorflow_estimator.python.estimator.canned.dnn_linear_combined import DNNLinearCombinedEstimator\nfrom tensorflow_estimator.python.estimator.canned.dnn_linear_combined import DNNLinearCombinedRegressor\nfrom tensorflow_estimator.python.estimator.canned.kmeans import KMeansClustering\nfrom tensorflow_estimator.python.estimator.canned.linear import 
linear_logit_fn_builder\nfrom tensorflow_estimator.python.estimator.canned.linear import LinearClassifier\nfrom tensorflow_estimator.python.estimator.canned.linear import LinearEstimator\nfrom tensorflow_estimator.python.estimator.canned.linear import LinearRegressor\nfrom tensorflow_estimator.python.estimator.canned.parsing_utils import classifier_parse_example_spec\nfrom tensorflow_estimator.python.estimator.canned.parsing_utils import regressor_parse_example_spec\nfrom tensorflow_estimator.python.estimator.canned.rnn import RNNClassifier\nfrom tensorflow_estimator.python.estimator.canned.rnn import RNNEstimator\nfrom tensorflow_estimator.python.estimator.early_stopping import *\nfrom tensorflow_estimator.python.estimator.estimator import Estimator\nfrom tensorflow_estimator.python.estimator.estimator import VocabInfo\nfrom tensorflow_estimator.python.estimator.estimator import WarmStartSettings\nfrom tensorflow_estimator.python.estimator.export import export_lib as export\nfrom tensorflow_estimator.python.estimator.exporter import Exporter\nfrom tensorflow_estimator.python.estimator.exporter import FinalExporter\nfrom tensorflow_estimator.python.estimator.exporter import LatestExporter\nfrom tensorflow_estimator.python.estimator.extenders import add_metrics\nfrom tensorflow_estimator.python.estimator.head.base_head import Head\nfrom tensorflow_estimator.python.estimator.head.binary_class_head import BinaryClassHead\nfrom tensorflow_estimator.python.estimator.head.multi_class_head import MultiClassHead\nfrom tensorflow_estimator.python.estimator.head.multi_head import MultiHead\nfrom tensorflow_estimator.python.estimator.head.multi_label_head import MultiLabelHead\nfrom tensorflow_estimator.python.estimator.head.regression_head import LogisticRegressionHead\nfrom tensorflow_estimator.python.estimator.head.regression_head import PoissonRegressionHead\nfrom tensorflow_estimator.python.estimator.head.regression_head import RegressionHead\nfrom 
tensorflow_estimator.python.estimator.hooks import basic_session_run_hooks\nfrom tensorflow_estimator.python.estimator.hooks import hooks\nfrom tensorflow_estimator.python.estimator.hooks import session_run_hook\nfrom tensorflow_estimator.python.estimator.inputs import inputs\nfrom tensorflow_estimator.python.estimator.keras_lib import model_to_estimator\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\nfrom tensorflow_estimator.python.estimator.model_fn import call_logit_fn\nfrom tensorflow_estimator.python.estimator.model_fn import EstimatorSpec\nfrom tensorflow_estimator.python.estimator.run_config import RunConfig\nfrom tensorflow_estimator.python.estimator.tpu.tpu_estimator import TPUEstimator\nfrom tensorflow_estimator.python.estimator.training import EvalSpec\nfrom tensorflow_estimator.python.estimator.training import train_and_evaluate\nfrom tensorflow_estimator.python.estimator.training import TrainSpec\n\n# pylint: enable=unused-import,line-too-long,wildcard-import\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/estimator_test.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for Estimator.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport functools\nimport glob\nimport json\nimport os\nimport socket\nimport tempfile\n\nimport numpy as np\nimport six\nimport tensorflow.compat.v1 as tf\nfrom google.protobuf import text_format\n\nfrom absl.testing import parameterized\nfrom tensorflow.core.protobuf import rewriter_config_pb2\nfrom tensorflow.python.framework import combinations\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.framework import test_util\nfrom tensorflow.python.lib.io import file_io\nfrom tensorflow.python.ops import control_flow_ops\nfrom tensorflow.python.ops.random_ops import random_uniform\nfrom tensorflow.python.platform import tf_logging as logging\nfrom tensorflow.python.platform import gfile\nfrom tensorflow.python.profiler import profiler_v2 as profiler\nfrom tensorflow.python.saved_model import loader_impl\nfrom tensorflow.python.saved_model import path_helpers\nfrom tensorflow.python.saved_model import tag_constants\nfrom tensorflow.python.training import checkpoint_state_pb2\nfrom tensorflow.python.training import saver_test_utils\nfrom tensorflow.python.training import training\nfrom 
tensorflow.python.util import function_utils\nfrom tensorflow_estimator.python.estimator import training as estimator_training\nfrom tensorflow_estimator.python.estimator import estimator\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v1\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator import run_config\nfrom tensorflow_estimator.python.estimator.export import export_lib\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n_TMP_DIR = '/tmp'\n_ANOTHER_TMP_DIR = '/another_tmp'\n\n\ndef dummy_model_fn(features, labels, params):\n  _, _, _ = features, labels, params\n\n\ndef summaries_with_matching_keyword(keyword, dir_):\n  \"\"\"Yields summary protos matching given keyword from event file.\"\"\"\n\n  tf.summary.FileWriterCache.clear()\n\n  event_paths = glob.glob(os.path.join(dir_, 'events*'))\n  for event in tf.train.summary_iterator(event_paths[-1]):\n    if event.summary is not None:\n      for value in event.summary.value:\n        if keyword in value.tag:\n          yield event.summary\n\n\ndef check_eventfile_for_keyword(keyword, dir_):\n  \"\"\"Checks event files for the keyword.\"\"\"\n  return any(summaries_with_matching_keyword(keyword, dir_))\n\n\ndef get_mock_saver():\n  real_saver = tf.train.Saver()\n  return tf.test.mock.Mock(wraps=real_saver, saver_def=real_saver.saver_def)\n\n\nclass EstimatorInheritanceConstraintTest(tf.test.TestCase):\n  \"\"\"Tests that sub classes cannot override methods of Estimator.\"\"\"\n\n  @property\n  def random_estimator(self):\n    switch = np.random.random()\n    return estimator.EstimatorV2 if switch > 0.5 else estimator.EstimatorV2\n\n  def test_override_a_method(self):\n\n    class _Estimator(self.random_estimator):\n\n      def __init__(self):\n        super(_Estimator, self).__init__(model_fn=dummy_model_fn)\n\n      def predict(self, 
input_fn, predict_keys=None, hooks=None):\n        pass\n\n    with self.assertRaisesRegexp(\n        ValueError, 'cannot override members of Estimator.*predict'):\n      _Estimator()\n\n  def test_extension_of_api_is_ok(self):\n\n    class _Estimator(self.random_estimator):\n\n      def __init__(self):\n        super(_Estimator, self).__init__(model_fn=dummy_model_fn)\n\n      def predict_proba(self, input_fn, predict_keys=None, hooks=None):\n        pass\n\n    _Estimator()\n\n  def test_override_allowed_method(self):\n\n    class _Estimator(self.random_estimator):\n\n      def __init__(self):\n        super(_Estimator, self).__init__(model_fn=dummy_model_fn)\n\n      def _tf_api_names(self):\n        pass\n\n    _Estimator()\n\n\nclass EstimatorConstructorTest(tf.test.TestCase):\n\n  def test_config_must_be_a_run_config(self):\n    with self.assertRaisesRegexp(ValueError, 'an instance of `RunConfig`'):\n      estimator.EstimatorV2(model_fn=None, config='NotARunConfig')\n\n  def test_model_fn_must_be_provided(self):\n    with self.assertRaisesRegexp(ValueError, 'model_fn.* must be'):\n      estimator.EstimatorV2(model_fn=None)\n\n  def test_property_accessors(self):\n\n    def model_fn(features, labels, params):\n      _, _, _ = features, labels, params\n\n    class FakeConfig(run_config.RunConfig):\n      pass\n\n    params = {'hidden_layers': [3, 4]}\n    est = estimator.EstimatorV2(\n        model_fn=model_fn, model_dir='bla', config=FakeConfig(), params=params)\n    self.assertTrue(isinstance(est.config, FakeConfig))\n    self.assertEqual(params, est.params)\n    self.assertEqual('bla', est.model_dir)\n\n  def test_default_config(self):\n\n    def model_fn(features, labels):\n      _, _ = features, labels\n\n    est = estimator.EstimatorV2(model_fn=model_fn)\n    self.assertTrue(isinstance(est.config, run_config.RunConfig))\n    self.assertTrue(est._session_config.allow_soft_placement)\n    rewrite_options = est._session_config.graph_options.rewrite_options\n 
   self.assertEqual(rewrite_options.meta_optimizer_iterations,\n                     rewriter_config_pb2.RewriterConfig.ONE)\n\n  def test_default_model_dir(self):\n\n    def model_fn(features, labels):\n      _, _ = features, labels\n\n    with tf.test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):\n      est = estimator.EstimatorV2(model_fn=model_fn)\n      self.assertEqual(_TMP_DIR, est.config.model_dir)\n      self.assertEqual(_TMP_DIR, est.model_dir)\n\n  def test_model_dir_in_constructor(self):\n\n    def model_fn(features, labels):\n      _, _ = features, labels\n\n    est = estimator.EstimatorV2(model_fn=model_fn, model_dir=_TMP_DIR)\n    self.assertEqual(_TMP_DIR, est.config.model_dir)\n    self.assertEqual(_TMP_DIR, est.model_dir)\n\n  def test_empty_model_dir(self):\n\n    def model_fn(features, labels):\n      _, _ = features, labels\n\n    with tf.test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):\n      est = estimator.EstimatorV2(model_fn=model_fn, model_dir='')\n      self.assertEqual(_TMP_DIR, est.config.model_dir)\n      self.assertEqual(_TMP_DIR, est.model_dir)\n\n  def test_model_dir_in_run_config(self):\n\n    class FakeConfig(run_config.RunConfig):\n\n      @property\n      def model_dir(self):\n        return _TMP_DIR\n\n    def model_fn(features, labels):\n      _, _ = features, labels\n\n    est = estimator.EstimatorV2(model_fn=model_fn, config=FakeConfig())\n    self.assertEqual(_TMP_DIR, est.config.model_dir)\n    self.assertEqual(_TMP_DIR, est.model_dir)\n\n  def test_same_model_dir_in_constructor_and_run_config(self):\n\n    class FakeConfig(run_config.RunConfig):\n\n      @property\n      def model_dir(self):\n        return _TMP_DIR\n\n    def model_fn(features, labels):\n      _, _ = features, labels\n\n    est = estimator.EstimatorV2(\n        model_fn=model_fn, config=FakeConfig(), model_dir=_TMP_DIR)\n    self.assertEqual(_TMP_DIR, est.config.model_dir)\n    self.assertEqual(_TMP_DIR, 
est.model_dir)\n\n  def test_different_model_dir_in_constructor_and_run_config(self):\n\n    class FakeConfig(run_config.RunConfig):\n\n      @property\n      def model_dir(self):\n        return _TMP_DIR\n\n    def model_fn(features, labels):\n      _, _ = features, labels\n\n    with self.assertRaisesRegexp(\n        ValueError,\n        '`model_dir` are set both in constructor and `RunConfig`, but '\n        'with different values'):\n      estimator.EstimatorV2(\n          model_fn=model_fn, config=FakeConfig(), model_dir=_ANOTHER_TMP_DIR)\n\n  def test_model_fn_args_must_include_features(self):\n\n    def model_fn(x, labels):\n      _, _ = x, labels\n\n    with self.assertRaisesRegexp(ValueError, 'features'):\n      estimator.EstimatorV2(model_fn=model_fn)\n\n  def test_model_fn_args_labels_is_optional(self):\n\n    def model_fn(features):\n      _ = features\n\n    estimator.EstimatorV2(model_fn=model_fn)\n\n  def test_if_params_provided_then_model_fn_should_accept_it(self):\n\n    def model_fn(features, labels):\n      _, _ = features, labels\n\n    estimator.EstimatorV2(model_fn=model_fn)\n    with self.assertRaisesRegexp(ValueError, 'params'):\n      estimator.EstimatorV2(model_fn=model_fn, params={'hidden_layers': 4})\n\n  def test_internal_params_is_a_deepcopy(self):\n\n    def model_fn(features, labels, params):\n      _, _, _ = features, labels, params\n\n    params = {'hidden_layers': 4}\n    est = estimator.EstimatorV2(model_fn=model_fn, params=params)\n\n    params['hidden_layers'] = 5\n    self.assertEqual(4, est.params['hidden_layers'])\n\n  def test_not_known_model_fn_args(self):\n\n    def model_fn(features, labels, something):\n      _, _, _ = features, labels, something\n\n    with self.assertRaisesRegexp(ValueError, 'something'):\n      estimator.EstimatorV2(model_fn=model_fn)\n\n  def test_not_known_model_fn_args_handled_by_lambda(self):\n\n    def model_fn(features, labels, something):\n      _, _, _ = features, labels, something\n\n    
new_model_fn = lambda features, labels: model_fn(  # pylint: disable=g-long-lambda\n        features, labels, 'something')\n    estimator.EstimatorV2(model_fn=new_model_fn)\n\n  def test_if_model_fn_is_a_member_function_of_a_class(self):\n\n    class ModelFnClass(object):\n\n      def __init__(self):\n        estimator.EstimatorV2(model_fn=self.model_fn)\n\n      def model_fn(self, features, labels, mode):\n        _, _, _ = features, labels, mode\n\n    ModelFnClass()\n\n  def test_model_fn_property_binds_params(self):\n\n    def model_fn(features, labels, mode, config, params):\n      _, _, _, _, _ = features, labels, mode, config, params\n\n    est = estimator.EstimatorV2(model_fn=model_fn)\n    model_fn_args = function_utils.fn_args(est.model_fn)\n    self.assertEqual(\n        set(['features', 'labels', 'mode', 'config']), set(model_fn_args))\n\n  def test_model_fn_property_returns_fixed_signature(self):\n\n    def model_fn(features, labels):\n      _, _ = features, labels\n\n    est = estimator.EstimatorV2(model_fn=model_fn)\n    model_fn_args = function_utils.fn_args(est.model_fn)\n    self.assertEqual(\n        set(['features', 'labels', 'mode', 'config']), set(model_fn_args))\n\n\ndef dummy_input_fn():\n  return ({'x': tf.constant([[1], [1]])}, tf.constant([[1], [1]]))\n\n\ndef model_fn_global_step_incrementer(features, labels, mode):\n  _, _ = features, labels\n  global_step = tf.train.get_global_step()\n  return model_fn_lib.EstimatorSpec(\n      mode, loss=tf.constant(1.), train_op=tf.assign_add(global_step, 1))\n\n\ndef assert_features_op(expected_features, actual_features):\n  return [\n      tf.debugging.assert_equal(\n          expected_features[k], actual_features[k], name='assert_%s' % k)\n      for k in expected_features\n  ]\n\n\ndef _estimator_spec(expected_features, expected_labels, actual_features,\n                    actual_labels, mode):\n  assert_ops = tuple(\n      assert_features_op(expected_features, actual_features) + [\n          
tf.debugging.assert_equal(\n              expected_labels, actual_labels, name='assert_labels')\n      ])\n  global_step = tf.train.get_global_step()\n  with tf.control_dependencies(assert_ops):\n    return model_fn_lib.EstimatorSpec(\n        mode=mode,\n        predictions=tf.constant(0.),\n        loss=tf.constant(0.),\n        train_op=tf.assign_add(global_step, 1))\n\n\ndef _make_input_fn(features, labels):\n\n  def _input_fn():\n    return {k: tf.constant(v) for k, v in six.iteritems(features)\n           }, tf.constant(labels)\n\n  return _input_fn\n\n\nclass EstimatorTrainTest(tf.test.TestCase):\n\n  def test_callable_model_fn(self):\n    expected_features = {'x': 42., 'y': 43.}\n    expected_labels = 44.\n\n    model_fn_call_count = [0]\n\n    test_self = self\n\n    class ModelFn(object):\n\n      def __call__(self, features, labels):\n        model_fn_call_count[0] += 1\n        test_self.assertItemsEqual(expected_features.keys(), features.keys())\n        return _estimator_spec(expected_features, expected_labels, features,\n                               labels, ModeKeys.TRAIN)\n\n    with self.assertRaisesRegexp(ValueError, 'does not include params'):\n      estimator.EstimatorV2(model_fn=ModelFn(), params={'a': 'b'})\n    est = estimator.EstimatorV2(\n        model_fn=ModelFn(), config=run_config.RunConfig())\n    self.assertEqual(0, model_fn_call_count[0])\n    est.train(\n        input_fn=_make_input_fn(expected_features, expected_labels), steps=1)\n    self.assertEqual(1, model_fn_call_count[0])\n\n  def test_callable_input_fn(self):\n    expected_mode = ModeKeys.TRAIN\n    expected_params = {'batch_size': 10}\n    expected_config = run_config.RunConfig().replace(tf_random_seed=4321)\n    input_fn_call_count = [0]\n\n    def _model_fn(features, labels, mode, params, config):\n      del params, config\n      return model_fn_global_step_incrementer(features, labels, mode)\n\n    test_self = self\n\n    class InputFn(object):\n\n      def 
__call__(self, mode, params, config):\n        input_fn_call_count[0] += 1\n        test_self.assertEqual(expected_mode, mode)\n        test_self.assertEqual(expected_params, params)\n        test_self.assertEqual(4321, config.tf_random_seed)\n        return dummy_input_fn()\n\n    est = estimator.EstimatorV2(\n        model_fn=_model_fn, params=expected_params, config=expected_config)\n    self.assertEqual(0, input_fn_call_count[0])\n    est.train(InputFn(), steps=1)\n    self.assertEqual(1, input_fn_call_count[0])\n\n  def test_nested_input_fn(self):\n    expected_params = {'batch_size': 10}\n\n    def _input_fn():\n      dataset_features = tf.data.Dataset.from_tensor_slices(\n          (random_uniform([4]),\n           random_uniform([4, 100], maxval=100, dtype=tf.dtypes.int32)))\n      dataset_labels = tf.data.Dataset.from_tensor_slices(\n          random_uniform([4, 10]))\n      dataset = tf.data.Dataset.zip((dataset_features, dataset_labels))\n      dataset = dataset.repeat(-1)\n      iterator = tf.data.make_initializable_iterator(dataset)\n      return iterator.get_next()\n\n    def _model_fn(features, labels, mode, params, config):\n      del params, config\n      return model_fn_global_step_incrementer(features, labels, mode)\n\n    expected_config = run_config.RunConfig().replace(tf_random_seed=4321)\n    est = estimator.EstimatorV2(\n        model_fn=_model_fn, params=expected_params, config=expected_config)\n    est.train(_input_fn, steps=4)\n\n  def test_input_fn_args(self):\n    expected_mode = ModeKeys.TRAIN\n    expected_params = {'batch_size': 10}\n    expected_config = run_config.RunConfig().replace(tf_random_seed=4321)\n    input_fn_call_count = [0]\n\n    def _model_fn(features, labels, mode, params, config):\n      del params, config\n      return model_fn_global_step_incrementer(features, labels, mode)\n\n    def _input_fn(mode, params, config):\n      input_fn_call_count[0] += 1\n      self.assertEqual(expected_mode, mode)\n      
self.assertEqual(expected_params, params)\n      self.assertEqual(4321, config.tf_random_seed)\n      return dummy_input_fn()\n\n    est = estimator.EstimatorV2(\n        model_fn=_model_fn, params=expected_params, config=expected_config)\n    self.assertEqual(0, input_fn_call_count[0])\n    est.train(_input_fn, steps=1)\n    self.assertEqual(1, input_fn_call_count[0])\n\n  def test_minimal_model_fn_args(self):\n    expected_features = {'x': 4, 'y': 5}\n\n    def _input_fn():\n      return expected_features\n\n    model_fn_call_count = [0]\n\n    def _model_fn(features):\n      model_fn_call_count[0] += 1\n      self.assertItemsEqual(expected_features.keys(), features.keys())\n      with tf.control_dependencies(\n          assert_features_op(expected_features, features)):\n        return model_fn_lib.EstimatorSpec(\n            mode=None,\n            predictions=tf.constant(0.),\n            loss=tf.constant(0.),\n            train_op=tf.assign_add(tf.train.get_global_step(), 1))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    self.assertEqual(0, model_fn_call_count[0])\n    est.train(input_fn=_input_fn, steps=1)\n    self.assertEqual(1, model_fn_call_count[0])\n\n  def test_labels_should_be_none_if_model_fn_does_not_use_labels(self):\n\n    def _input_fn_with_labels():\n      return {'x': 4, 'y': 5}, [4]\n\n    def _model_fn(features):\n      _ = features\n      return model_fn_lib.EstimatorSpec(\n          mode=None,\n          predictions=tf.constant(0.),\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    with self.assertRaisesRegexp(ValueError, 'model_fn does not take labels'):\n      est.train(input_fn=_input_fn_with_labels, steps=1)\n\n  def test_input_fn_len_should_be_2_if_tuple_or_list(self):\n\n    def _input_fn():\n      return 4, 5, 6\n\n    def _model_fn(features):\n      _ = features\n\n    est = 
estimator.EstimatorV2(model_fn=_model_fn)\n    with self.assertRaisesRegexp(ValueError, 'len 2 tuple'):\n      est.train(input_fn=_input_fn, steps=1)\n\n  def test_all_model_fn_args(self):\n    expected_features = {'x': 42., 'y': 43.}\n    expected_labels = 44.\n    expected_params = {'some_param': 'some_value'}\n    expected_config = run_config.RunConfig()\n    expected_config.i_am_test = True\n\n    # TODO(ptucker): We have to roll our own mock since Estimator._get_arguments\n    # doesn't work with mock fns.\n    model_fn_call_count = [0]\n\n    # Note that args are all passed by keyword, so can be in any order.\n    def _model_fn(mode, params, features, labels, config):\n      model_fn_call_count[0] += 1\n      self.assertItemsEqual(expected_features.keys(), features.keys())\n      self.assertEqual(ModeKeys.TRAIN, mode)\n      self.assertEqual(expected_params, params)\n      self.assertTrue(config.i_am_test)\n      return _estimator_spec(expected_features, expected_labels, features,\n                             labels, mode)\n\n    est = estimator.EstimatorV2(\n        model_fn=_model_fn, params=expected_params, config=expected_config)\n    self.assertEqual(0, model_fn_call_count[0])\n    est.train(\n        input_fn=_make_input_fn(expected_features, expected_labels), steps=1)\n    self.assertEqual(1, model_fn_call_count[0])\n\n  def test_partial_model_fn_args(self):\n    expected_features = {'x': 42., 'y': 43.}\n    expected_labels = 44.\n    expected_params = {'some_param': 'some_value'}\n    expected_config = run_config.RunConfig()\n    expected_config.i_am_test = True\n    expected_foo = 45.\n    expected_bar = 46.\n\n    # TODO(ptucker): We have to roll our own mock since Estimator._get_arguments\n    # doesn't work with mock fns.\n    model_fn_call_count = [0]\n\n    def _model_fn(features, labels, foo, mode, params, config, bar):\n      model_fn_call_count[0] += 1\n      self.assertEqual(expected_foo, foo)\n      self.assertEqual(expected_bar, bar)\n    
  def test_model_fn_must_return_estimator_spec(self):
    """train() rejects a model_fn whose return value is not an EstimatorSpec."""

    def model_fn(features, labels):
      del features, labels
      return 'NotGoodNotGood'

    est = estimator.EstimatorV2(model_fn=model_fn)
    with self.assertRaisesRegexp(ValueError, 'EstimatorSpec'):
      est.train(dummy_input_fn, steps=1)
steps=5)\n    self.assertIsNotNone(est.latest_checkpoint())\n    self.assertTrue(est.latest_checkpoint().startswith(est.model_dir))\n\n  def test_steps_and_saves_reloads(self):\n    est = estimator.EstimatorV2(model_fn=model_fn_global_step_incrementer)\n    est.train(dummy_input_fn, steps=5)\n    self.assertEqual(\n        5, estimator._load_global_step_from_checkpoint_dir(est.model_dir))\n    est.train(dummy_input_fn, steps=5)\n    self.assertEqual(\n        10, estimator._load_global_step_from_checkpoint_dir(est.model_dir))\n\n  def test_warm_starts(self):\n\n    def _make_model_fn(x):\n\n      def _variable_creating_model_fn(features, labels, mode):\n        _, _ = features, labels\n        tf.get_variable('x', initializer=x)\n        global_step = tf.train.get_global_step()\n        return model_fn_lib.EstimatorSpec(\n            mode, loss=tf.constant(1.), train_op=tf.assign_add(global_step, 1))\n\n      return _variable_creating_model_fn\n\n    est = estimator.EstimatorV2(model_fn=_make_model_fn(42.))\n    est.train(dummy_input_fn, steps=10)\n\n    warm_started_est = estimator.EstimatorV2(\n        model_fn=_make_model_fn(36.), warm_start_from=est.model_dir)\n    warm_started_est.train(dummy_input_fn, steps=5)\n    # warm_start is called after the model_fn, so x should have the value\n    # from the checkpoint.\n    self.assertEqual(42., warm_started_est.get_variable_value('x'))\n    # global_step should not be warm-started.\n    self.assertEqual(\n        5,\n        estimator._load_global_step_from_checkpoint_dir(\n            warm_started_est.model_dir))\n\n  @test_util.run_v1_only('b/119219961')\n  def test_warm_starts_from_savedmodel(self):\n\n    def _make_model_fn(x):\n\n      def _variable_creating_and_export_model_fn(features, labels, mode):\n        _, _ = features, labels\n        tf.get_variable('x', initializer=x)\n        global_step = tf.train.get_global_step()\n        return model_fn_lib.EstimatorSpec(\n            mode,\n            
predictions={'y': tf.constant(1.0)},\n            loss=tf.constant(1.),\n            train_op=tf.assign_add(global_step, 1),\n            export_outputs={\n                'test':\n                    export_lib.ClassificationOutput(\n                        tf.constant([4.2]), tf.constant(['label']))\n            })\n\n      return _variable_creating_and_export_model_fn\n\n    est = estimator.EstimatorV2(model_fn=_make_model_fn(42.))\n    est.train(dummy_input_fn, steps=10)\n    feature_spec = {\n        'x': tf.io.VarLenFeature(dtype=tf.dtypes.int64),\n        'y': tf.io.VarLenFeature(dtype=tf.dtypes.int64)\n    }\n    serving_input_receiver_fn = (\n        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))\n    tmpdir = tempfile.mkdtemp()\n    export_dir_base = os.path.join(\n        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))\n    export_dir = est.export_saved_model(export_dir_base,\n                                        serving_input_receiver_fn)\n\n    warm_started_est = estimator.EstimatorV2(\n        model_fn=_make_model_fn(36.), warm_start_from=export_dir)\n    warm_started_est.train(dummy_input_fn, steps=5)\n    # warm_start is called after the model_fn, so x should have the value\n    # from the SavedModel.\n    self.assertEqual(42., warm_started_est.get_variable_value('x'))\n\n  def test_max_step(self):\n    est = estimator.EstimatorV2(model_fn=model_fn_global_step_incrementer)\n    est.train(dummy_input_fn, max_steps=5)\n    self.assertEqual(\n        5, estimator._load_global_step_from_checkpoint_dir(est.model_dir))\n    est.train(dummy_input_fn, max_steps=5)\n    self.assertEqual(\n        5, estimator._load_global_step_from_checkpoint_dir(est.model_dir))\n\n  def test_checkpoint_contains_relative_paths(self):\n    tmpdir = tempfile.mkdtemp()\n    est = estimator.EstimatorV2(\n        model_dir=tmpdir, model_fn=model_fn_global_step_incrementer)\n    est.train(dummy_input_fn, steps=5)\n\n    checkpoint_file_content = 
file_io.read_file_to_string(\n        os.path.join(tmpdir, 'checkpoint'))\n    ckpt = checkpoint_state_pb2.CheckpointState()\n    text_format.Merge(checkpoint_file_content, ckpt)\n    self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')\n    # TODO(b/78461127): Please modify tests to not directly rely on names of\n    # checkpoints.\n    self.assertAllEqual(['model.ckpt-0', 'model.ckpt-5'],\n                        ckpt.all_model_checkpoint_paths)\n\n  def test_train_save_copy_reload(self):\n    tmpdir = tempfile.mkdtemp()\n    model_dir1 = os.path.join(tmpdir, 'model_dir1')\n    est1 = estimator.EstimatorV2(\n        model_dir=model_dir1, model_fn=model_fn_global_step_incrementer)\n    est1.train(dummy_input_fn, steps=5)\n\n    # We have to clear the cache before we can rename the directory,\n    # otherwise open file handles will prevent the delete on Windows.\n    tf.summary.FileWriterCache.clear()\n    model_dir2 = os.path.join(tmpdir, 'model_dir2')\n    os.renames(model_dir1, model_dir2)\n\n    est2 = estimator.EstimatorV2(\n        model_dir=model_dir2, model_fn=model_fn_global_step_incrementer)\n    self.assertEqual(\n        5, estimator._load_global_step_from_checkpoint_dir(est2.model_dir))\n    est2.train(dummy_input_fn, steps=5)\n    self.assertEqual(\n        10, estimator._load_global_step_from_checkpoint_dir(est2.model_dir))\n\n  def test_steps0_raises_error(self):\n    est = estimator.EstimatorV2(model_fn=_model_fn_with_eval_metric_ops)\n    with self.assertRaisesRegexp(ValueError, 'Must specify steps > 0'):\n      est.train(dummy_input_fn, steps=0)\n\n  def test_steps_negative_raises_error(self):\n    est = estimator.EstimatorV2(model_fn=_model_fn_with_eval_metric_ops)\n    with self.assertRaisesRegexp(ValueError, 'Must specify steps > 0'):\n      est.train(dummy_input_fn, steps=-1)\n\n  def test_max_steps0_raises_error(self):\n    est = estimator.EstimatorV2(model_fn=_model_fn_with_eval_metric_ops)\n    with 
self.assertRaisesRegexp(ValueError, 'Must specify max_steps > 0'):\n      est.train(dummy_input_fn, max_steps=0)\n\n  def test_max_steps_negative_raises_error(self):\n    est = estimator.EstimatorV2(model_fn=_model_fn_with_eval_metric_ops)\n    with self.assertRaisesRegexp(ValueError, 'Must specify max_steps > 0'):\n      est.train(dummy_input_fn, max_steps=-1)\n\n  def test_scaffold_is_used(self):\n    self.is_init_fn_called = False\n\n    def _init_fn(scaffold, sess):\n      _, _ = scaffold, sess\n      self.is_init_fn_called = True\n\n    def _model_fn_scaffold(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode=mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          scaffold=tf.train.Scaffold(init_fn=_init_fn))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn_scaffold)\n    est.train(dummy_input_fn, steps=1)\n    self.assertTrue(self.is_init_fn_called)\n\n  def test_hooks_should_be_session_run_hook(self):\n    est = estimator.EstimatorV2(model_fn=model_fn_global_step_incrementer)\n    with self.assertRaisesRegexp(TypeError, 'must be a SessionRunHook'):\n      est.train(dummy_input_fn, steps=1, hooks=['NotAHook'])\n\n  def test_training_hooks_are_used(self):\n    chief_hook = tf.test.mock.MagicMock(\n        wraps=tf.train.SessionRunHook(), spec=tf.train.SessionRunHook)\n    hook = tf.test.mock.MagicMock(\n        wraps=tf.train.SessionRunHook(), spec=tf.train.SessionRunHook)\n\n    def _model_fn_hooks(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode=mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          training_chief_hooks=[chief_hook],\n          training_hooks=[hook])\n\n    est = estimator.EstimatorV2(model_fn=_model_fn_hooks)\n    self.assertFalse(chief_hook.begin.called)\n    
self.assertFalse(hook.begin.called)\n    est.train(dummy_input_fn, steps=1)\n    self.assertTrue(chief_hook.begin.called)\n    self.assertTrue(hook.begin.called)\n\n  def test_saving_listeners_are_used(self):\n    listener = tf.test.mock.Mock(spec=tf.train.CheckpointSaverListener)\n    listener.after_save.return_value = None\n    est = estimator.EstimatorV2(\n        model_fn=model_fn_global_step_incrementer,\n        config=run_config.RunConfig(save_checkpoints_steps=10))\n    est.train(dummy_input_fn, steps=26, saving_listeners=[listener])\n    self.assertEqual(4, listener.before_save.call_count)\n    self.assertEqual(4, listener.after_save.call_count)\n\n  def test_saver_hook_should_exist_to_use_saving_listeners(self):\n    listener = tf.test.mock.Mock(spec=tf.train.CheckpointSaverListener)\n    est = estimator.EstimatorV2(\n        model_fn=model_fn_global_step_incrementer,\n        config=run_config.RunConfig(\n            save_checkpoints_steps=None, save_checkpoints_secs=None))\n    with self.assertRaisesRegexp(ValueError,\n                                 'CheckpointSaverHook to use saving_listeners'):\n      est.train(dummy_input_fn, steps=1, saving_listeners=[listener])\n\n  def test_listeners_should_be_listeners(self):\n    est = estimator.EstimatorV2(model_fn=model_fn_global_step_incrementer)\n    with self.assertRaisesRegexp(TypeError,\n                                 'must be a list of CheckpointSaverListener'):\n      est.train(dummy_input_fn, steps=1, saving_listeners=['not-a-listener'])\n\n  def test_chief_only_hook_should_not_be_called_on_non_chief(self):\n    chief_hook = tf.test.mock.MagicMock(\n        wraps=tf.train.SessionRunHook(), spec=tf.train.SessionRunHook)\n    hook = tf.test.mock.MagicMock(\n        wraps=tf.train.SessionRunHook(), spec=tf.train.SessionRunHook)\n\n    def _model_fn_hooks(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode=mode,\n          
loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          training_chief_hooks=[chief_hook],\n          training_hooks=[hook])\n\n    class NonChiefRunConfig(run_config.RunConfig):\n\n      @property\n      def is_chief(self):  # pylint: disable=g-wrong-blank-lines\n        return False\n\n    # Mocking the SessionManager.wait_for_session, so that worker doesn't wait\n    # for chief.\n    def get_initialized_session(*args, **kwargs):\n      # Session doesn't take 'max_wait_secs' argument.\n      kwargs.pop('max_wait_secs', None)\n      scaffold = tf.train.Scaffold().finalize()\n      sess = tf.Session(*args, **kwargs)\n      sess.run(scaffold.init_op)\n      return sess\n\n    with tf.test.mock.patch.object(\n        tf.train.SessionManager,\n        'wait_for_session',\n        side_effect=get_initialized_session):\n      est = estimator.EstimatorV2(\n          model_fn=_model_fn_hooks, config=NonChiefRunConfig())\n      self.assertFalse(chief_hook.begin.called)\n      self.assertFalse(hook.begin.called)\n      est.train(dummy_input_fn, steps=1)\n      self.assertFalse(chief_hook.begin.called)\n      self.assertTrue(hook.begin.called)\n\n  def test_features_labels_mode(self):\n    given_features = {'test-features': [[1], [1]]}\n    given_labels = {'test-labels': [[1], [1]]}\n\n    def _input_fn():\n      return given_features, given_labels\n\n    def _model_fn(features, labels, mode):\n      self.features, self.labels, self.mode = features, labels, mode\n      return model_fn_lib.EstimatorSpec(\n          mode=mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions=tf.constant([[0.]]))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(_input_fn, steps=1)\n    self.assertEqual(given_features, self.features)\n    self.assertEqual(given_labels, self.labels)\n    self.assertEqual(ModeKeys.TRAIN, self.mode)\n\n  def 
test_graph_initialization_global_step_and_random_seed(self):\n    expected_random_seed = run_config.RunConfig().tf_random_seed\n\n    def _model_fn(features, labels, mode):\n      _, _, _ = features, labels, mode\n      self.assertIsNotNone(tf.train.get_global_step())\n      self.assertEqual(expected_random_seed, tf.get_default_graph().seed)\n      return model_fn_lib.EstimatorSpec(\n          mode=mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions=tf.constant([[0.]]))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n\n  def test_config_should_not_be_evaluator_or_ps(self):\n\n    class FakeEvaluatorConfig(run_config.RunConfig):\n\n      @property\n      def task_type(self):\n        return run_config.TaskType.EVALUATOR\n\n    est = estimator.EstimatorV2(\n        model_fn=dummy_model_fn, config=FakeEvaluatorConfig())\n    with self.assertRaisesRegexp(ValueError, 'train_and_evaluate'):\n      est.train(dummy_input_fn, steps=1)\n\n  def test_master_distributed_hooks(self):\n    tf_config = json.dumps({\n        'cluster': {\n            run_config.TaskType.PS: ['localhost:1234'],\n            run_config.TaskType.WORKER: ['localhost:1235'],\n            run_config.TaskType.MASTER: ['localhost:1236']\n        },\n        'task': {\n            'type': run_config.TaskType.MASTER,\n            'index': 0\n        }\n    })\n    with tf.test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):\n      est = estimator.EstimatorV2(\n          model_fn=model_fn_global_step_incrementer,\n          config=run_config.RunConfig())\n\n    with tf.test.mock.patch.object(training,\n                                   'MonitoredTrainingSession') as mock_sess:\n      est.train(dummy_input_fn, steps=1)\n      self.assertFalse(\n          any(\n              isinstance(hook, tf.train.SummarySaverHook)\n              for hook in 
mock_sess.call_args[1]['hooks']))\n      self.assertFalse(\n          any(\n              isinstance(hook, tf.train.StepCounterHook)\n              for hook in mock_sess.call_args[1]['hooks']))\n      self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])\n      self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])\n\n  def test_master_distributed_hooks_for_worker_0(self):\n    tf_config = json.dumps({\n        'cluster': {\n            run_config.TaskType.PS: ['localhost:1234'],\n            run_config.TaskType.WORKER: ['localhost:1235'],\n            run_config.TaskType.MASTER: ['localhost:1236']\n        },\n        'task': {\n            'type': run_config.TaskType.WORKER,\n            'index': 0\n        }\n    })\n    with tf.test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):\n      est = estimator.EstimatorV2(\n          model_fn=model_fn_global_step_incrementer,\n          config=run_config.RunConfig())\n\n    with tf.test.mock.patch.object(training,\n                                   'MonitoredTrainingSession') as mock_sess:\n      est.train(dummy_input_fn, steps=1)\n      self.assertTrue(\n          any(\n              isinstance(hook, tf.train.SummarySaverHook)\n              for hook in mock_sess.call_args[1]['hooks']))\n      self.assertTrue(\n          any(\n              isinstance(hook, tf.train.StepCounterHook)\n              for hook in mock_sess.call_args[1]['hooks']))\n      self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])\n      self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])\n\n  def test_master_distributed_hooks_for_worker_nonzero(self):\n    tf_config = json.dumps({\n        'cluster': {\n            run_config.TaskType.PS: ['localhost:1234'],\n            run_config.TaskType.WORKER: ['localhost:1235', 'localhost:1237'],\n            run_config.TaskType.MASTER: ['localhost:1236']\n        },\n        'task': {\n            'type': run_config.TaskType.WORKER,\n          
  'index': 1\n        }\n    })\n    with tf.test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):\n      est = estimator.EstimatorV2(\n          model_fn=model_fn_global_step_incrementer,\n          config=run_config.RunConfig())\n\n    with tf.test.mock.patch.object(training,\n                                   'MonitoredTrainingSession') as mock_sess:\n      est.train(dummy_input_fn, steps=1)\n      self.assertFalse(\n          any(\n              isinstance(hook, tf.train.SummarySaverHook)\n              for hook in mock_sess.call_args[1]['hooks']))\n      self.assertFalse(\n          any(\n              isinstance(hook, tf.train.StepCounterHook)\n              for hook in mock_sess.call_args[1]['hooks']))\n      self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])\n      self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])\n\n  def test_master_hooks_single_replica(self):\n    tf_config = json.dumps({\n        'cluster': {\n            run_config.TaskType.MASTER: ['localhost:1234']\n        },\n        'task': {\n            'type': run_config.TaskType.MASTER,\n            'index': 0\n        }\n    })\n    with tf.test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):\n      est = estimator.EstimatorV2(\n          model_fn=model_fn_global_step_incrementer,\n          config=run_config.RunConfig(\n              save_summary_steps=100, log_step_count_steps=200))\n\n    with tf.test.mock.patch.object(training,\n                                   'MonitoredTrainingSession') as mock_sess:\n      est.train(dummy_input_fn, steps=1)\n      self.assertFalse(\n          any(\n              isinstance(hook, tf.train.SummarySaverHook)\n              for hook in mock_sess.call_args[1]['hooks']))\n      self.assertFalse(\n          any(\n              isinstance(hook, tf.train.StepCounterHook)\n              for hook in mock_sess.call_args[1]['hooks']))\n      self.assertEqual(100, mock_sess.call_args[1]['save_summaries_steps'])\n  
    self.assertEqual(200, mock_sess.call_args[1]['log_step_count_steps'])\n\n  def test_master_hooks_single_replica_with_ps(self):\n    tf_config = json.dumps({\n        'cluster': {\n            run_config.TaskType.MASTER: ['localhost:1234'],\n            run_config.TaskType.PS: ['localhost: 1235'],\n        },\n        'task': {\n            'type': run_config.TaskType.MASTER,\n            'index': 0\n        }\n    })\n    with tf.test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):\n      est = estimator.EstimatorV2(\n          model_fn=model_fn_global_step_incrementer,\n          config=run_config.RunConfig(\n              save_summary_steps=100, log_step_count_steps=200))\n\n    with tf.test.mock.patch.object(training,\n                                   'MonitoredTrainingSession') as mock_sess:\n      est.train(dummy_input_fn, steps=1)\n      self.assertFalse(\n          any(\n              isinstance(hook, tf.train.SummarySaverHook)\n              for hook in mock_sess.call_args[1]['hooks']))\n      self.assertFalse(\n          any(\n              isinstance(hook, tf.train.StepCounterHook)\n              for hook in mock_sess.call_args[1]['hooks']))\n      self.assertEqual(100, mock_sess.call_args[1]['save_summaries_steps'])\n      self.assertEqual(200, mock_sess.call_args[1]['log_step_count_steps'])\n\n  def test_hooks_with_distributed_collective_ops(self):\n    if tf.executing_eagerly():\n      self.skipTest('n/a: legacy graph only')\n    tf_config = json.dumps({\n        'cluster': {\n            run_config.TaskType.WORKER: ['', ''],\n        },\n        'task': {\n            'type': run_config.TaskType.WORKER,\n            'index': 0\n        }\n    })\n    # We let it skip setting eager context in multi-worker path by creating a\n    # single-worker strategy and then passing cluster info into it.\n    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()\n    strategy.configure(\n        cluster_spec={\n            
run_config.TaskType.WORKER: ['', ''],\n        },\n        task_type=run_config.TaskType.WORKER,\n        task_id=0)\n    with tf.test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):\n      config = run_config.RunConfig(\n          train_distribute=strategy,\n          save_summary_steps=1000,\n          save_checkpoints_steps=500)\n      config._distribute_coordinator_mode = None  # Skip distribute coordintor.\n      est = estimator.EstimatorV2(\n          model_fn=model_fn_global_step_incrementer, config=config)\n\n    def input_fn():\n      return tf.data.Dataset.from_tensors(({\n          'x': tf.constant([[1], [1]])\n      }, tf.constant([[1], [1]])))\n\n    with tf.test.mock.patch.object(training,\n                                   'MonitoredTrainingSession') as mock_sess:\n      est.train(input_fn, steps=1)\n      self.assertFalse(\n          any(\n              isinstance(hook, tf.train.SummarySaverHook)\n              for hook in mock_sess.call_args[1]['hooks']))\n      self.assertFalse(\n          any(\n              isinstance(hook, tf.train.StepCounterHook)\n              for hook in mock_sess.call_args[1]['hooks']))\n      self.assertFalse(\n          any(\n              isinstance(hook, tf.train.CheckpointSaverHook)\n              for hook in mock_sess.call_args[1]['hooks']))\n      self.assertEqual(1000, mock_sess.call_args[1]['save_summaries_steps'])\n      self.assertEqual(500, mock_sess.call_args[1]['save_checkpoint_steps'])\n      self.assertEqual(100, mock_sess.call_args[1]['log_step_count_steps'])\n\n\ndef _model_fn_with_eval_metric_ops(features, labels, mode, params):\n  _, _ = features, labels\n  global_step = tf.train.get_global_step()\n  loss = tf.constant(1.)\n  metric_name_1 = params.get('metric_name') or 'metric'\n  metric_value_1 = params.get('metric_value') or 2.\n  metric_name_2 = params.get('metric_name_2') or 'metric2'\n  metric_value_2 = params.get('metric_value_2') or 2.\n\n  metric_update_op = loss.op\n  metric_tensor = 
class _StepCounterHook(tf.train.SessionRunHook):
  """Session hook that tallies how many times `before_run` is invoked."""

  def __init__(self):
    self._count = 0

  def before_run(self, run_context):
    # The context itself is irrelevant here; only the call count matters.
    del run_context
    self._count += 1

  @property
  def steps(self):
    return self._count
class EstimatorTraceTest(tf.test.TestCase, parameterized.TestCase):
  # Scaffolding for profiler-trace tests around Estimator training.

  def setUp(self):
    # Per-test directory where profiler output is expected to land.
    self._profiler_dir = os.path.join(self.get_temp_dir(), 'profiler')

  # Class-level fixture state, built once at class-creation time.
  # NOTE(review): model_fn_call_count is a shared mutable list — it
  # accumulates across every use of ModelFn on this class.
  expected_features = {'x': 42., 'y': 43.}
  expected_labels = 44.
  model_fn_call_count = [0]
  # input_fn is constructed here from the attributes above, so their
  # definition order in this class body is load-bearing.
  input_fn = _make_input_fn(expected_features, expected_labels)

  class ModelFn(object):
    # Callable model_fn that counts invocations and delegates to
    # _estimator_spec (defined earlier in this file) in TRAIN mode.

    def __call__(self, features, labels):
      EstimatorTraceTest.model_fn_call_count[0] += 1
      return _estimator_spec(EstimatorTraceTest.expected_features,
                             EstimatorTraceTest.expected_labels, features,
                             labels, ModeKeys.TRAIN)
test_with_predict(self):\n\n    def _input_fn():\n      return tf.data.Dataset.from_tensors([10.])\n\n    def _model_fn(features, labels, mode):\n      _ = labels\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          predictions=features,  # 10\n          loss=features,  # 10\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(_input_fn, steps=1)\n    self.assertEqual([10.], next(est.predict(input_fn=_input_fn)))\n\n  def test_batching(self):\n\n    def _input_fn():\n      return tf.data.Dataset.from_tensor_slices(\n          ([[1.], [2.]], [[10.], [20.]])).batch(1)\n\n    def _model_fn(features, labels, mode):\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          predictions=features,\n          loss=features + (0 if labels is None else labels),  # 11, 22\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(_input_fn)\n    scores = est.evaluate(_input_fn)\n    # (11 + 22)/2 = 16.5\n    self.assertEqual(16.5, scores[model_fn_lib.LOSS_METRIC_KEY])\n    self.assertEqual([1., 2.], list(est.predict(_input_fn)))\n\n\nclass EstimatorEvaluateTest(tf.test.TestCase):\n\n  def test_eval_dir(self):\n    est = estimator.EstimatorV2(\n        model_fn=model_fn_global_step_incrementer, model_dir='some_path')\n    expected_eval_dir = os.path.join('some_path', 'eval')\n    self.assertEqual(expected_eval_dir, est.eval_dir())\n    expected_eval_dir_name = os.path.join('some_path', 'eval_a_name')\n    self.assertEqual(expected_eval_dir_name, est.eval_dir('a_name'))\n\n  def test_input_fn_args(self):\n    expected_mode = ModeKeys.EVAL\n    expected_params = {'batch_size': 10}\n    expected_config = run_config.RunConfig().replace(tf_random_seed=4321)\n    input_fn_call_count = [0]\n\n    def _model_fn(features, labels, mode, params, config):\n      del params, config\n   
   return model_fn_global_step_incrementer(features, labels, mode)\n\n    def _input_fn(mode, params, config):\n      input_fn_call_count[0] += 1\n      self.assertEqual(expected_mode, mode)\n      self.assertEqual(expected_params, params)\n      self.assertEqual(4321, config.tf_random_seed)\n      return dummy_input_fn()\n\n    est = estimator.EstimatorV2(\n        model_fn=_model_fn, params=expected_params, config=expected_config)\n    est.train(dummy_input_fn, steps=1)\n    self.assertEqual(0, input_fn_call_count[0])\n    est.evaluate(_input_fn, steps=1)\n    self.assertEqual(1, input_fn_call_count[0])\n\n  def test_model_fn_must_return_estimator_spec(self):\n\n    def _model_fn(features, labels, mode):\n      _, _ = features, labels\n      if mode == ModeKeys.EVAL:\n        return 'NotGoodNotGood'\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(1.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n    with self.assertRaisesRegexp(ValueError,\n                                 'model_fn should return an EstimatorSpec'):\n      est.evaluate(dummy_input_fn, steps=1)\n\n  def test_no_checkpoint_uses_init(self):\n\n    def _model_fn(features, labels, mode, params):\n      del features, labels, params\n      mean = tf_keras_v1.metrics.Mean()\n      mean.update_state(tf.Variable(2.) + 1)\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(1.),\n          eval_metric_ops={\n              'mean1': mean,\n              'mean2': tf.metrics.mean(tf.compat.v1.Variable(2.) 
+ 1)\n          })\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    scores = est.evaluate(dummy_input_fn, steps=1)\n    # Metric value here is set to 1 + the value of the Variable that is newly\n    # initialized (since there is no checkpoint).\n    self.assertEqual(3., scores['mean1'])\n    self.assertEqual(3., scores['mean2'])\n\n  @test_util.run_v1_only('b/119219961')\n  def test_no_checkpoint_uses_init_with_warm_starting(self):\n\n    def _make_model_fn(x):\n\n      def _variable_creating_and_export_model_fn(features, labels, mode):\n        _, _ = features, labels\n        x_var = tf.get_variable('x', initializer=x)\n        global_step = tf.train.get_global_step()\n        mean = tf_keras_v1.metrics.Mean()\n        mean.update_state(x_var + 1)\n        return model_fn_lib.EstimatorSpec(\n            mode,\n            predictions={'y': tf.constant(1.0)},\n            loss=tf.constant(1.),\n            eval_metric_ops={\n                'mean1': mean,\n                'mean2': tf.metrics.mean(x_var + 1)\n            },\n            train_op=tf.assign_add(global_step, 1),\n            export_outputs={\n                'test':\n                    export_lib.ClassificationOutput(\n                        tf.constant([4.2]), tf.constant(['label']))\n            })\n\n      return _variable_creating_and_export_model_fn\n\n    first_est = estimator.EstimatorV2(model_fn=_make_model_fn(42.))\n    first_est.train(dummy_input_fn, steps=10)\n    feature_spec = {\n        'x': tf.io.VarLenFeature(dtype=tf.dtypes.int64),\n        'y': tf.io.VarLenFeature(dtype=tf.dtypes.int64)\n    }\n    serving_input_receiver_fn = (\n        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))\n    tmpdir = tempfile.mkdtemp()\n    export_dir_base = os.path.join(\n        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))\n    exported_path = first_est.export_saved_model(export_dir_base,\n                                                 
serving_input_receiver_fn)\n\n    # Test that we can pass either warm_start_from as an external checkpoint\n    # or an exported SavedModel.\n    est = estimator.EstimatorV2(\n        model_fn=_make_model_fn(52.), warm_start_from=exported_path)\n    eval_metrics = est.evaluate(dummy_input_fn, steps=1)\n    # Metric value here is set to 1 + the value of the Variable that is\n    # warm-started from the SavedModel of the first model (42.), as opposed to\n    # the initialization in the new model_fn (52.).\n    self.assertEqual(43., eval_metrics['mean1'])\n    self.assertEqual(43., eval_metrics['mean2'])\n\n    est = estimator.EstimatorV2(\n        model_fn=_make_model_fn(62.), warm_start_from=first_est.model_dir)\n    eval_metrics = est.evaluate(dummy_input_fn, steps=1)\n    # Metric value here is set to 1 + the value of the Variable that is\n    # warm-started from a checkpoint of the first model (42.), as opposed to\n    # the initialization in the new model_fn (52.).\n    self.assertEqual(43., eval_metrics['mean1'])\n    self.assertEqual(43., eval_metrics['mean2'])\n\n  def test_scores(self):\n    est = estimator.EstimatorV2(\n        model_fn=_model_fn_with_eval_metric_ops,\n        params={\n            'metric_name': 'metric',\n            'metric_value': 2.,\n            'metric_name_2': 'metric2',\n            'metric_value_2': 3.,\n        })\n    est.train(dummy_input_fn, steps=5)\n    scores = est.evaluate(dummy_input_fn, steps=1)\n    self.assertIn('metric', scores)\n    self.assertAlmostEqual(2., scores['metric'])\n    self.assertIn('metric2', scores)\n    self.assertAlmostEqual(3., scores['metric2'])\n\n  def test_tuple_metrics(self):\n\n    def _model_fn(features, labels, mode):\n      del features  # unused\n      del labels\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          loss=tf.constant(1.),\n          eval_metric_ops={\n              'nested_metric': 
(\n                  ((tf.constant(2.), tf.constant(1)),\n                   tf.constant(3., dtype=tf.dtypes.float64)), tf.no_op())\n          })\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n    evaluation = est.evaluate(dummy_input_fn, steps=1)\n    ((two_float, one_integer), three_double) = evaluation['nested_metric']\n    self.assertAlmostEqual(2., two_float)\n    self.assertEqual(1, one_integer)\n    self.assertAlmostEqual(3., three_double)\n\n  def test_steps0_raises_error(self):\n    est = estimator.EstimatorV2(model_fn=_model_fn_with_eval_metric_ops)\n    est.train(dummy_input_fn, steps=5)\n    with self.assertRaisesRegexp(ValueError, 'Must specify steps > 0'):\n      est.evaluate(dummy_input_fn, steps=0)\n\n  def test_steps_negative_raises_error(self):\n    est = estimator.EstimatorV2(model_fn=_model_fn_with_eval_metric_ops)\n    est.train(dummy_input_fn, steps=5)\n    with self.assertRaisesRegexp(ValueError, 'Must specify steps > 0'):\n      est.evaluate(dummy_input_fn, steps=-1)\n\n  def test_global_step_metric_raises_error(self):\n    est = estimator.EstimatorV2(\n        model_fn=_model_fn_with_eval_metric_ops,\n        params={\n            'metric_name': 'global_step',\n            'metric_value': 2.\n        })\n    est.train(dummy_input_fn, steps=5)\n    with self.assertRaisesRegexp(\n        ValueError, 'Metric with name `global_step` is not allowed'):\n      est.evaluate(dummy_input_fn, steps=1)\n\n  def test_global_step_is_reported(self):\n    est = estimator.EstimatorV2(\n        model_fn=_model_fn_with_eval_metric_ops,\n        params={\n            'metric_name': 'metric',\n            'metric_value': 2.,\n            'metric_name_2': 'metric2',\n            'metric_value_2': 3.,\n        })\n    est.train(dummy_input_fn, steps=5)\n    scores = est.evaluate(dummy_input_fn, steps=1)\n    self.assertIn('global_step', scores)\n    self.assertEqual(5, scores['global_step'])\n\n  def 
test_loss_metric_is_reported(self):\n\n    def _model_fn_with_incremental_loss(features, labels, mode):\n      _, _ = features, labels\n      local_weight = tf.Variable(\n          0., name='local_weight', collections=[tf.GraphKeys.LOCAL_VARIABLES])\n      # Loss will be 2, 4, 6, ...\n      loss = 2 * tf.assign_add(local_weight, 1.)\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=loss,\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn_with_incremental_loss)\n    est.train(dummy_input_fn, steps=1)\n    scores = est.evaluate(dummy_input_fn, steps=5)\n    self.assertIn(model_fn_lib.LOSS_METRIC_KEY, scores)\n    # Average loss will be (2 + 4 + 6 + 8 + 10)/5=6\n    self.assertAlmostEqual(6., scores[model_fn_lib.LOSS_METRIC_KEY])\n\n  def test_hooks_should_be_session_run_hook(self):\n    est = estimator.EstimatorV2(model_fn=model_fn_global_step_incrementer)\n    est.train(dummy_input_fn, steps=1)\n    with self.assertRaisesRegexp(TypeError, 'must be a SessionRunHook'):\n      est.evaluate(dummy_input_fn, steps=5, hooks=['NotAHook'])\n\n  def test_hooks_are_used(self):\n    step_counter_hook = _StepCounterHook()\n\n    est = estimator.EstimatorV2(model_fn=_model_fn_with_eval_metric_ops)\n    est.train(dummy_input_fn, steps=1)\n    est.evaluate(dummy_input_fn, steps=5, hooks=[step_counter_hook])\n    self.assertEqual(5, step_counter_hook.steps)\n\n  def test_evaluate_from_checkpoint(self):\n    params = {\n        'metric_name': 'metric',\n        'metric_value': 2.,\n        'metric_name_2': 'metric2',\n        'metric_value_2': 3.,\n    }\n    est1 = estimator.EstimatorV2(\n        model_fn=_model_fn_with_eval_metric_ops, params=params)\n    est1.train(dummy_input_fn, steps=5)\n    est2 = estimator.EstimatorV2(\n        model_fn=_model_fn_with_eval_metric_ops, params=params)\n    scores = est2.evaluate(\n        dummy_input_fn, steps=1, 
checkpoint_path=est1.latest_checkpoint())\n    self.assertEqual(5, scores['global_step'])\n\n  @test_util.run_v1_only('VariableV1 is only exported in v1')\n  def test_wrong_shape_throws_reasonable_error(self):\n    \"\"\"Make sure we are helpful when model_fns change. See b/110263146.\"\"\"\n\n    def _get_model_fn(val=1):\n\n      def _model_fn(features, labels, mode):\n        del features, labels  # unused\n        tf.Variable(val, name='weight')\n        return model_fn_lib.EstimatorSpec(\n            mode=mode,\n            predictions=tf.constant([[1.]]),\n            loss=tf.constant(0.),\n            train_op=tf.assign_add(tf.train.get_global_step(), 1))\n\n      return _model_fn\n\n    model_fn_1 = _get_model_fn()\n    model_fn_2 = _get_model_fn(val=[1])\n\n    est1 = estimator.EstimatorV2(model_fn=model_fn_1)\n    est1.train(dummy_input_fn, steps=5)\n    est2 = estimator.EstimatorV2(model_fn=model_fn_2, model_dir=est1.model_dir)\n\n    expected_msg = 'Restoring from checkpoint failed.*a mismatch between'\n    with self.assertRaisesRegexp(tf.errors.InvalidArgumentError, expected_msg):\n      est2.train(\n          dummy_input_fn,\n          steps=1,\n      )\n\n  def test_scaffold_is_used(self):\n\n    def _model_fn_scaffold(features, labels, mode):\n      _, _ = features, labels\n      tf.Variable(1., name='weight')\n      self.mock_saver = get_mock_saver()\n      return model_fn_lib.EstimatorSpec(\n          mode=mode,\n          predictions=tf.constant([[1.]]),\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          scaffold=tf.train.Scaffold(saver=self.mock_saver))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn_scaffold)\n    est.train(dummy_input_fn, steps=1)\n    est.evaluate(dummy_input_fn, steps=1)\n    self.assertTrue(self.mock_saver.restore.called)\n\n  def test_features_labels_mode(self):\n    given_features = {'test-features': [[1], [1]]}\n    given_labels = {'test-labels': 
[[1], [1]]}\n\n    def _input_fn():\n      return given_features, given_labels\n\n    def _model_fn(features, labels, mode):\n      self.features, self.labels, self.mode = features, labels, mode\n      return model_fn_lib.EstimatorSpec(\n          mode=mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions=tf.constant([[0.]]))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(_input_fn, steps=1)\n    est.evaluate(_input_fn, steps=1)\n    self.assertEqual(given_features, self.features)\n    self.assertEqual(given_labels, self.labels)\n    self.assertEqual(ModeKeys.EVAL, self.mode)\n\n  def test_graph_initialization_global_step_and_random_seed(self):\n    expected_random_seed = run_config.RunConfig().tf_random_seed\n\n    def _model_fn(features, labels, mode):\n      _, _, _ = features, labels, mode\n      self.assertIsNotNone(tf.train.get_global_step())\n      self.assertEqual(expected_random_seed, tf.get_default_graph().seed)\n      return model_fn_lib.EstimatorSpec(\n          mode=mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions=tf.constant([[0.]]))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n    est.evaluate(dummy_input_fn, steps=1)\n\n  def test_evaluation_hooks_are_used(self):\n    hook = tf.test.mock.MagicMock(\n        wraps=tf.train.SessionRunHook(), spec=tf.train.SessionRunHook)\n\n    def _model_fn_hooks(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode=mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          evaluation_hooks=[hook])\n\n    est = estimator.EstimatorV2(model_fn=_model_fn_hooks)\n    est.train(dummy_input_fn, steps=1)\n    self.assertFalse(hook.begin.called)\n    
est.evaluate(dummy_input_fn, steps=1)\n    self.assertTrue(hook.begin.called)\n\n  def test_summary_writing_with_summary_proto(self):\n\n    def model_fn_global_step_incrementer_image(features, labels, mode):\n      _, _ = features, labels\n      global_step = tf.train.get_global_step()\n\n      image = tf.zeros([5, 3, 3, 1])\n      eval_metric_ops = {\n          'foo': (tf.summary.image('image', image,\n                                   max_outputs=3), tf.constant(1))\n      }\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(1.),\n          train_op=tf.assign_add(global_step, 1),\n          eval_metric_ops=eval_metric_ops)\n\n    est = estimator.EstimatorV2(\n        model_fn=model_fn_global_step_incrementer_image,\n        config=run_config.RunConfig(save_summary_steps=1))\n    est.train(dummy_input_fn, steps=200)\n    est.evaluate(\n        input_fn=dummy_input_fn,\n        steps=200,\n    )\n\n    # Make sure nothing is stuck in limbo.\n    tf.summary.FileWriterCache.clear()\n\n    # Get last evaluation Event written.\n    for key in ['foo/0', 'foo/1', 'foo/2']:\n      self.assertTrue(\n          check_eventfile_for_keyword(key, est.eval_dir()),\n          '{} should be part of reported summaries.'.format(key))\n\n    # Verify that evaluated checkpoint path is written to event file.\n    checkpoint_path_tag = 'checkpoint_path'\n    self.assertTrue(\n        check_eventfile_for_keyword(checkpoint_path_tag, est.eval_dir()),\n        '{} should be part of reported summaries.'.format(checkpoint_path_tag))\n\n    expected_tensor_proto = tf.make_tensor_proto(\n        est.latest_checkpoint(), dtype=tf.dtypes.string)\n    summaries = summaries_with_matching_keyword(checkpoint_path_tag,\n                                                est.eval_dir())\n    self.assertProtoEquals(expected_tensor_proto,\n                           next(summaries).value[0].tensor)\n\n  def test_summary_writing_with_tensor(self):\n\n    def 
model_fn_with_prediction_mean_tensor_eval_metric_ops(\n        features, labels, mode, params):\n      _, _ = features, labels\n      global_step = tf.train.get_global_step()\n\n      metric_name = params.get('metric_name') or 'metric'\n      predictions = tf.constant([1., .5, 0.])\n      eval_metric_ops = {metric_name: tf.metrics.mean_tensor(predictions)}\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(1.),\n          predictions={'predictions': predictions},\n          train_op=tf.assign_add(global_step, 1),\n          eval_metric_ops=eval_metric_ops)\n\n    metric_key = 'PMT'\n    params = {\n        'metric_name': metric_key,\n    }\n    est = estimator.EstimatorV2(\n        model_fn=model_fn_with_prediction_mean_tensor_eval_metric_ops,\n        params=params,\n        config=run_config.RunConfig(save_summary_steps=1))\n    est.train(input_fn=dummy_input_fn, steps=10)\n    est.evaluate(\n        input_fn=dummy_input_fn,\n        steps=10,\n    )\n\n    tf.summary.FileWriterCache.clear()\n\n    self.assertTrue(\n        check_eventfile_for_keyword(metric_key, est.eval_dir()),\n        '{} should be part of reported summaries.'.format(metric_key))\n\n    summaries = summaries_with_matching_keyword(metric_key, est.eval_dir())\n    for value in next(summaries).value:\n      if value.tag == metric_key:\n        self.assertTrue(value.HasField('tensor'))\n\n\nclass EstimatorPredictTest(tf.test.TestCase):\n\n  def test_input_fn_args(self):\n    expected_mode = ModeKeys.PREDICT\n    expected_params = {'batch_size': 10}\n    expected_config = run_config.RunConfig().replace(tf_random_seed=4321)\n    input_fn_call_count = [0]\n\n    def _model_fn(features, labels, mode, params, config):\n      del features, labels, params, config\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          
predictions=tf.constant([[10.]]))\n\n    def _input_fn(mode, params, config):\n      input_fn_call_count[0] += 1\n      self.assertEqual(expected_mode, mode)\n      self.assertEqual(expected_params, params)\n      self.assertEqual(4321, config.tf_random_seed)\n      return dummy_input_fn()\n\n    est = estimator.EstimatorV2(\n        model_fn=_model_fn, params=expected_params, config=expected_config)\n    est.train(dummy_input_fn, steps=1)\n    self.assertEqual(0, input_fn_call_count[0])\n    next(est.predict(_input_fn))\n    self.assertEqual(1, input_fn_call_count[0])\n\n  def test_no_checkpoint_uses_init(self):\n\n    def _model_fn(features, labels, mode, params, config):\n      del features, labels, params, config\n      x = tf.Variable([[3.]], name='x')\n      return model_fn_lib.EstimatorSpec(mode, predictions=tf.math.add(x, 1.))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    # Expected prediction value is 1 + the value of the Variable that is newly\n    # initialized (since there is no checkpoint).\n    self.assertEqual(4., next(est.predict(dummy_input_fn)))\n\n  @test_util.run_v1_only('b/119219961')\n  def test_no_checkpoint_uses_init_with_warm_starting(self):\n\n    def _make_model_fn(x):\n\n      def _variable_creating_and_export_model_fn(features, labels, mode):\n        _, _ = features, labels\n        x_var = tf.Variable([[x]], name='x')\n        return model_fn_lib.EstimatorSpec(\n            mode,\n            predictions=tf.math.add(x_var, 1.),\n            loss=tf.constant(1.),\n            train_op=tf.assign_add(tf.train.get_global_step(), 1),\n            export_outputs={\n                'test':\n                    export_lib.ClassificationOutput(\n                        tf.constant([4.2]), tf.constant(['label']))\n            })\n\n      return _variable_creating_and_export_model_fn\n\n    first_est = estimator.EstimatorV2(model_fn=_make_model_fn(3.))\n    first_est.train(dummy_input_fn, steps=10)\n    feature_spec = {\n        
'x': tf.io.VarLenFeature(dtype=tf.dtypes.int64),\n        'y': tf.io.VarLenFeature(dtype=tf.dtypes.int64)\n    }\n    serving_input_receiver_fn = (\n        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))\n    tmpdir = tempfile.mkdtemp()\n    export_dir_base = os.path.join(\n        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))\n    exported_path = first_est.export_saved_model(export_dir_base,\n                                                 serving_input_receiver_fn)\n\n    # Test that we can pass either warm_start_from as an external checkpoint\n    # or an exported SavedModel.\n    est = estimator.EstimatorV2(\n        model_fn=_make_model_fn(30.), warm_start_from=exported_path)\n    # Prediction here is set to 1 + the value of the Variable that is\n    # warm-started from the SavedModel of the first model (3.), as opposed to\n    # the initialization in the new model_fn (30.).\n    self.assertEqual(4., next(est.predict(dummy_input_fn)))\n\n    est = estimator.EstimatorV2(\n        model_fn=_make_model_fn(40.), warm_start_from=first_est.model_dir)\n    # Prediction here is set to 1 + the value of the Variable that is\n    # warm-started from a checkpoint of the first model (3.), as opposed to\n    # the initialization in the new model_fn (40.).\n    self.assertEqual(4., next(est.predict(dummy_input_fn)))\n\n  def test_no_trained_model_invalid_checkpoint_path(self):\n    est = estimator.EstimatorV2(model_fn=model_fn_global_step_incrementer)\n    with self.assertRaises(ValueError):\n      next(\n          est.predict(\n              dummy_input_fn,\n              checkpoint_path=tf.train.latest_checkpoint('fakedir')))\n\n  def test_tensor_predictions(self):\n\n    def _model_fn(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          
predictions=tf.constant([[10.]]))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n    self.assertEqual(10., next(est.predict(dummy_input_fn)))\n\n  def test_predictionhooks_are_used(self):\n    hook = tf.test.mock.MagicMock(\n        wraps=tf.train.SessionRunHook(), spec=tf.train.SessionRunHook)\n\n    def _model_fn_hooks(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode=mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions=tf.constant([[10.]]),\n          prediction_hooks=[hook])\n\n    est = estimator.EstimatorV2(model_fn=_model_fn_hooks)\n    est.train(dummy_input_fn, steps=1)\n    self.assertFalse(hook.begin.called)\n    next(est.predict(dummy_input_fn))\n    self.assertTrue(hook.begin.called)\n\n  def test_warn_if_no_queue_runner(self):\n\n    def _model_fn(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions=tf.constant([[10.]]))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n    with tf.test.mock.patch.object(tf.logging, 'warning') as mock_log:\n      next(est.predict(dummy_input_fn))\n      self.assertRegexpMatches(\n          str(mock_log.call_args),\n          'Input graph does not.*contain a QueueRunner.')\n\n  def test_skip_warn_if_dataset_returns_features(self):\n\n    def _model_fn(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions=tf.constant([[10.]]))\n\n    def _input_fn():\n      dataset = tf.data.Dataset.from_tensors([1])\n 
     iterator = tf.data.make_one_shot_iterator(dataset)\n      return iterator.get_next()\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n    with tf.test.mock.patch.object(tf.logging, 'warning') as mock_log:\n      next(est.predict(_input_fn))\n      # The warning should not have keyword QueueRunner.\n      self.assertRegexpMatches(str(mock_log.call_args), '^((?!QueueRunner).)*$')\n\n  def test_skip_warn_if_dataset_returns_features_dict(self):\n\n    def _model_fn(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions=tf.constant([[10.]]))\n\n    def _input_fn():\n      dataset = tf.data.Dataset.from_tensors([1])\n      iterator = tf.data.make_one_shot_iterator(dataset)\n      features = {'age': iterator.get_next()}\n      return features\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n    with tf.test.mock.patch.object(tf.logging, 'warning') as mock_log:\n      next(est.predict(_input_fn))\n      # The warning should not have keyword QueueRunner.\n      self.assertRegexpMatches(str(mock_log.call_args), '^((?!QueueRunner).)*$')\n\n  def test_input_fn_can_return_just_features(self):\n\n    def _model_fn(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions=tf.constant([[10.]]))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n\n    def _only_features():\n      return {'x': tf.constant([[0.]])}\n\n    self.assertEqual([10.], next(est.predict(_only_features)))\n\n  def test_batch_size_mismatch(self):\n\n    def _model_fn(features, labels, mode):\n   
   _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions={\n              'y1': tf.constant([[10.]]),\n              'y2': tf.constant([[12.], [13]])\n          })\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n    with self.assertRaisesRegexp(ValueError,\n                                 'Batch length of predictions should be same'):\n      next(est.predict(dummy_input_fn))\n\n  def test_iterate_batches(self):\n\n    def _model_fn(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions={\n              # First dim is different but the prediction should still work\n              'y1': tf.zeros(shape=[3]),\n              'y2': tf.zeros(shape=[5, 3])\n          })\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n\n    predictions = next(est.predict(dummy_input_fn, yield_single_examples=False))\n    self.assertAllEqual(predictions['y1'].shape, [3])\n    self.assertAllEqual(predictions['y2'].shape, [5, 3])\n\n  def test_predict_keys_defined_for_tensor(self):\n\n    def _model_fn(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions=tf.constant([[10.]]))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n    with self.assertRaisesRegexp(\n        ValueError,\n        'predict_keys argument is not valid in case of non-dict predictions'):\n      next(est.predict(dummy_input_fn, 
predict_keys=['y']))\n\n  def test_predict_keys_does_not_exists(self):\n\n    def _model_fn(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions={\n              'y1': tf.constant([[10.]]),\n              'y2': tf.constant([[12.]])\n          })\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n    with self.assertRaisesRegexp(ValueError,\n                                 'Expected to run at least one output from'):\n      next(est.predict(dummy_input_fn, predict_keys=['y3']))\n\n  def test_return_given_predict_keys(self):\n\n    def _model_fn(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions={\n              'y1': tf.constant([[10.]]),\n              'y2': tf.constant([[12.]])\n          })\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n    results = next(est.predict(dummy_input_fn, predict_keys=['y1']))\n    self.assertIn('y1', results)\n    self.assertNotIn('y2', results)\n\n  def test_yield_rows_of_tensor(self):\n\n    def _model_fn(features, labels, mode):\n      _, _ = features, labels\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions=tf.constant([[10.], [12.]]))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n    results = est.predict(dummy_input_fn)\n    self.assertEqual([10.], next(results))\n    self.assertEqual([12.], next(results))\n\n  def test_yield_rows_of_dict(self):\n\n    def 
_model_fn(features, labels, mode):
      _, _ = features, labels
      return model_fn_lib.EstimatorSpec(
          mode,
          loss=tf.constant(0.),
          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
          predictions={
              'y1': tf.constant([[10.], [12]]),
              'y2': tf.constant([[0.], [2.]])
          })

    est = estimator.EstimatorV2(model_fn=_model_fn)
    est.train(dummy_input_fn, steps=1)
    results = est.predict(dummy_input_fn)
    self.assertDictEqual({'y1': [10.], 'y2': [0.]}, next(results))
    self.assertDictEqual({'y1': [12.], 'y2': [2.]}, next(results))

  def test_hooks_should_be_session_run_hook(self):
    est = estimator.EstimatorV2(model_fn=model_fn_global_step_incrementer)
    est.train(dummy_input_fn, steps=1)
    with self.assertRaisesRegexp(TypeError, 'must be a SessionRunHook'):
      next(est.predict(dummy_input_fn, hooks=['NotAHook']))

  def test_hooks_are_used(self):

    def _model_fn(features, labels, mode):
      _, _ = features, labels
      return model_fn_lib.EstimatorSpec(
          mode,
          loss=tf.constant(0.),
          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
          predictions=tf.constant([[10.], [12.]]))

    step_counter_hook = _StepCounterHook()
    est = estimator.EstimatorV2(model_fn=_model_fn)
    est.train(dummy_input_fn, steps=1)
    results = est.predict(dummy_input_fn, hooks=[step_counter_hook])
    self.assertEqual(0, step_counter_hook.steps)  # not called yet
    next(results)
    self.assertEqual(1, step_counter_hook.steps)  # first call
    next(results)
    self.assertEqual(1, step_counter_hook.steps)  # it's in same batch
    next(results)
    self.assertEqual(2, step_counter_hook.steps)  # next batch

  def test_predict_from_old_model_dir(self):

    def _model_fn(features, labels, mode):
      _, _ = features, labels
      v = tf.Variable([[16.]], name='weight')
      prediction = v * 2
      return model_fn_lib.EstimatorSpec(
          mode,
          loss=tf.constant(0.),
          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
          predictions=prediction)

    est1 = estimator.EstimatorV2(model_fn=_model_fn)
    est1.train(dummy_input_fn, steps=1)
    # A fresh estimator pointed at est1's model_dir must restore the trained
    # 'weight' variable (16.) and so predict 16 * 2 = 32.
    est2 = estimator.EstimatorV2(model_fn=_model_fn, model_dir=est1.model_dir)
    self.assertEqual([32.], next(est2.predict(dummy_input_fn)))

  def test_predict_from_checkpoint_path(self):

    def _model_fn(features, labels, mode):
      _, _ = features, labels
      v = tf.Variable([[16.]], name='weight')
      prediction = v * 2
      return model_fn_lib.EstimatorSpec(
          mode,
          loss=tf.constant(0.),
          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
          predictions=prediction)

    est1 = estimator.EstimatorV2(model_fn=_model_fn)
    est1.train(dummy_input_fn, steps=1)
    est2 = estimator.EstimatorV2(model_fn=_model_fn, model_dir=est1.model_dir)
    self.assertEqual([32.],
                     next(
                         est2.predict(
                             dummy_input_fn,
                             checkpoint_path=est2.latest_checkpoint())))

  def test_scaffold_is_used(self):

    def _model_fn_scaffold(features, labels, mode):
      _, _ = features, labels
      tf.Variable(1., name='weight')
      self.mock_saver = get_mock_saver()
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          predictions=tf.constant([[1.]]),
          loss=tf.constant(0.),
          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
          scaffold=tf.train.Scaffold(saver=self.mock_saver))

    est = estimator.EstimatorV2(model_fn=_model_fn_scaffold)
    est.train(dummy_input_fn, steps=1)
    next(est.predict(dummy_input_fn))
    # predict() must restore via the scaffold-provided saver, not a default one.
    self.assertTrue(self.mock_saver.restore.called)

  def test_features_labels_mode(self):
    given_features = {'test-features': [[1], [1]]}
    given_labels = {'test-labels': [[1], [1]]}

    def _input_fn():
      return given_features, given_labels

    def _model_fn(features, labels, mode):
      self.features, self.labels, self.mode = features, labels, mode
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          loss=tf.constant(0.),
          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
          predictions=tf.constant([[0.]]))

    est = estimator.EstimatorV2(model_fn=_model_fn)
    est.train(_input_fn, steps=1)
    next(est.predict(_input_fn))
    self.assertEqual(given_features, self.features)
    # In PREDICT mode the model_fn must be called with labels=None.
    self.assertIsNone(self.labels)
    self.assertEqual(ModeKeys.PREDICT, self.mode)

  def test_graph_initialization_global_step_and_random_seed(self):
    expected_random_seed = run_config.RunConfig().tf_random_seed

    def _model_fn(features, labels, mode):
      _, _, _ = features, labels, mode
      self.assertIsNotNone(tf.train.get_global_step())
      self.assertEqual(expected_random_seed, tf.get_default_graph().seed)
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          loss=tf.constant(0.),
          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
          predictions=tf.constant([[0.]]))

    est = estimator.EstimatorV2(model_fn=_model_fn)
    est.train(dummy_input_fn, steps=1)
    next(est.predict(dummy_input_fn))


def _model_fn_for_export_tests(features, labels, mode):
  _, _ = features, labels
  tf.Variable(1., name='weight')
  scores = tf.constant([3.])
  classes = tf.constant(['wumpus'])
  update_global_step = tf.assign_add(tf.train.get_global_step(), 1)
  with tf.control_dependencies([update_global_step]):
    train_op = tf.constant(2.)
  return model_fn_lib.EstimatorSpec(
      mode,
      predictions=tf.constant(10.),
      loss=tf.constant(1.),
      train_op=train_op,
      export_outputs={'test': export_lib.ClassificationOutput(scores, classes)})


def _x_y_input_fn():
  return ({
      'x': tf.constant([[1], [1]], name='feature_x'),
      'y': tf.constant([[2], [2]], name='feature_y')
  }, tf.constant([[1], [1]], name='truth'))


# Serves all three modes. TRAIN/EVAL prefix op and metric names ('eval_' for
# EVAL) so per-mode graphs can be told apart in the export tests below.
def _model_fn_with_x_y(features, labels, mode):
  _ = labels
  tf.Variable(1., name='weight')
  scores = tf.constant([3.])
  classes = tf.constant(['wumpus'])
  if mode == ModeKeys.PREDICT:
    tf.Variable(36., name='name_collision')
    return model_fn_lib.EstimatorSpec(
        mode,
        predictions=tf.constant(10.),
        export_outputs={
            'test': export_lib.ClassificationOutput(scores, classes)
        })
  else:
    prefix = 'eval_' if mode == ModeKeys.EVAL else ''

    multiplied = tf.math.multiply(
        features['x'], features['y'], name='{}multiplied'.format(prefix))
    mean = tf_keras_v1.metrics.Mean(name='{}mean'.format(prefix))
    mean.update_state(features['x'] - features['y'])
    eval_metrics = {
        'mean1':
            mean,
        'mean2':
            tf.metrics.mean(
                features['x'] - features['y'], name='{}mean'.format(prefix))
    }
    tf.Variable(1., name='later_var')
    tf.Variable(3., name='name_collision')
    return model_fn_lib.EstimatorSpec(
        mode,
        predictions=multiplied,
        loss=tf.constant(1.),
        train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
        eval_metric_ops=eval_metrics)


def _model_fn_with_saveables_for_export_tests(features, labels, mode):
  _, _ = features, labels
  table = saver_test_utils.CheckpointedOp(name='v2')
  update_global_step = tf.assign_add(tf.train.get_global_step(), 1)
  with tf.control_dependencies([update_global_step]):
    train_op = table.insert('k1', 30.0)
  prediction = table.lookup('k1', 0.0)
  return model_fn_lib.EstimatorSpec(
      mode,
      predictions=prediction,
      loss=tf.constant(1.),
      train_op=train_op,
      export_outputs={
          'test': export_lib.PredictOutput({'prediction': prediction})
      })


def _get_serving_input_receiver_fn():
  feature_spec = {
      'x': tf.io.VarLenFeature(dtype=tf.dtypes.int64),
      'y': tf.io.VarLenFeature(dtype=tf.dtypes.int64)
  }
  return export_lib.build_parsing_serving_input_receiver_fn(feature_spec)


def _get_supervised_input_receiver_fn():
  return export_lib.build_supervised_input_receiver_fn_from_input_fn(
      _x_y_input_fn)


_VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n'
_EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n'


@test_util.run_v1_only('b/119219961')
class EstimatorExportTest(tf.test.TestCase):

  def test_export_saved_model_proto_roundtrip_raw_receiver(self):
    tmpdir = tempfile.mkdtemp()
    est = estimator.EstimatorV2(model_fn=_model_fn_for_export_tests)
    est.train(input_fn=dummy_input_fn, steps=1)

    # Perform the export.
    export_dir_base = os.path.join(
        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))
    serving_input_receiver_fn = _get_serving_input_receiver_fn()
    export_dir = est.export_saved_model(export_dir_base,
                                        serving_input_receiver_fn)

    # Check that all the files are in the right places.
    self.assertTrue(tf.gfile.Exists(export_dir_base))
    self._validate_exported_files(export_dir)

    # Restore, to validate that the export was well-formed.
    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        tf.saved_model.load(sess, [tf.saved_model.SERVING], export_dir)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertTrue('input_example_tensor' in graph_ops)
        self.assertTrue('ParseExample/ParseExampleV2' in graph_ops)
        self.assertTrue('weight' in graph_ops)

  def test_export_saved_model_train(self):
    self._test_export_saved_model_for_mode(_get_supervised_input_receiver_fn(),
                                           ModeKeys.TRAIN)

  def test_export_saved_model_eval(self):
    self._test_export_saved_model_for_mode(_get_supervised_input_receiver_fn(),
                                           ModeKeys.EVAL)

  def test_export_saved_model_predict(self):
    self._test_export_saved_model_for_mode(_get_serving_input_receiver_fn(),
                                           ModeKeys.PREDICT)

  def _test_export_saved_model_for_mode(self, input_receiver_fn, mode):
    # Shared driver: trains, exports for `mode`, and checks the export loads
    # under that mode's tag set without duplicated ('name_collision_1') ops.
    tmpdir = tempfile.mkdtemp()
    est = estimator.EstimatorV2(model_fn=_model_fn_for_export_tests)
    est.train(input_fn=_x_y_input_fn, steps=1)

    # Perform the export.
    export_dir_base = os.path.join(
        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))
    export_dir = est.export_saved_model(
        export_dir_base, input_receiver_fn, experimental_mode=mode)

    # Check that all the files are in the right places.
    self.assertTrue(tf.gfile.Exists(export_dir_base))
    self._validate_exported_files(export_dir)

    # Restore, to validate that the export was well-formed.
    tag_set = export_lib.EXPORT_TAG_MAP[mode]
    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        tf.saved_model.load(sess, tag_set, export_dir)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertFalse('name_collision_1' in graph_ops)
        self.assertTrue('weight' in graph_ops)

    # Clean up.
    tf.gfile.DeleteRecursively(tmpdir)

  def test_export_all_saved_models_proto_roundtrip_receiver_map(self):
    input_receiver_fn_map = {ModeKeys.PREDICT: _get_serving_input_receiver_fn()}
    export_dir, tmpdir = self._test_export_all_saved_models(
        input_receiver_fn_map)

    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        tf.saved_model.load(sess, [tf.saved_model.SERVING], export_dir)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertTrue('input_example_tensor' in graph_ops)
        self.assertTrue('ParseExample/ParseExampleV2' in graph_ops)
        self.assertFalse('feature_x' in graph_ops)
        self.assertTrue('weight' in graph_ops)

    # Clean up.
    tf.gfile.DeleteRecursively(tmpdir)

  def test_export_all_saved_models_proto_roundtrip_train_only(self):
    input_receiver_fn_map = {
        ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
    }
    export_dir, tmpdir = self._test_export_all_saved_models(
        input_receiver_fn_map)

    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        tf.saved_model.load(sess, [tf.saved_model.TRAINING], export_dir)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertTrue('multiplied' in graph_ops)
        self.assertTrue('mean/update_op' in graph_ops)
        self.assertFalse('eval_multiplied' in graph_ops)
        self.assertTrue('feature_x' in graph_ops)
        self.assertTrue('weight' in graph_ops)

    # Clean up.
    tf.gfile.DeleteRecursively(tmpdir)

  def test_export_all_saved_models_proto_roundtrip_eval_only(self):
    input_receiver_fn_map = {ModeKeys.EVAL: _get_supervised_input_receiver_fn()}
    export_dir, tmpdir = self._test_export_all_saved_models(
        input_receiver_fn_map)

    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        tf.saved_model.load(sess, [tag_constants.EVAL], export_dir)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertTrue('eval_multiplied' in graph_ops)
        self.assertTrue('eval_mean/value' in graph_ops)
        self.assertFalse('multiplied' in graph_ops)
        self.assertTrue('feature_x' in graph_ops)
        self.assertTrue('weight' in graph_ops)

    # Clean up.
    tf.gfile.DeleteRecursively(tmpdir)

  def test_export_all_saved_models_proto_roundtrip_no_serving(self):
    input_receiver_fn_map = {
        ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
        ModeKeys.EVAL: _get_supervised_input_receiver_fn()
    }
    export_dir, tmpdir = self._test_export_all_saved_models(
        input_receiver_fn_map)

    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        tf.saved_model.load(sess, [tf.saved_model.TRAINING], export_dir)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertTrue('multiplied' in graph_ops)
        self.assertFalse('eval_multiplied' in graph_ops)
        self.assertTrue('feature_x' in graph_ops)
        self.assertTrue('weight' in graph_ops)

    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        tf.saved_model.load(sess, [tag_constants.EVAL], export_dir)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertTrue('eval_multiplied' in graph_ops)
        self.assertFalse('multiplied' in graph_ops)
        self.assertTrue('feature_x' in graph_ops)
        self.assertTrue('feature_y' in graph_ops)
        self.assertTrue('weight' in graph_ops)

    # Clean up.
    tf.gfile.DeleteRecursively(tmpdir)

  def test_export_all_saved_models_proto_roundtrip_three_defs(self):
    input_receiver_fn_map = {
        ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
        ModeKeys.EVAL: _get_supervised_input_receiver_fn(),
        ModeKeys.PREDICT: _get_serving_input_receiver_fn()
    }
    export_dir, tmpdir = self._test_export_all_saved_models(
        input_receiver_fn_map)

    # Restore, to validate that the export was well-formed.
    for tag_set in export_lib.EXPORT_TAG_MAP.values():
      with tf.Graph().as_default() as graph:
        with tf.Session(graph=graph) as sess:
          tf.saved_model.load(sess, tag_set, export_dir)
          graph_ops = [x.name for x in graph.get_operations()]
          self.assertTrue('global_step/Assign' in graph_ops)
          self.assertTrue('global_step/Initializer/zeros' in graph_ops)
          self.assertTrue('weight' in graph_ops)

    # Clean up.
    tf.gfile.DeleteRecursively(tmpdir)

  def test_export_all_saved_models_proto_roundtrip_all_vars(self):
    input_receiver_fn_map = {
        ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
        ModeKeys.PREDICT: _get_serving_input_receiver_fn()
    }
    export_dir, tmpdir = self._test_export_all_saved_models(
        input_receiver_fn_map)

    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        tf.saved_model.load(sess, [tf.saved_model.TRAINING], export_dir)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertTrue('later_var' in graph_ops)
        self.assertTrue('weight' in graph_ops)

    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        tf.saved_model.load(sess, [tf.saved_model.SERVING], export_dir)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertFalse('later_var' in graph_ops)
        self.assertTrue('weight' in graph_ops)

    # Clean up.
    tf.gfile.DeleteRecursively(tmpdir)

  def test_export_all_saved_models_name_collision(self):
    input_receiver_fn_map = {
        ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
        ModeKeys.PREDICT: _get_serving_input_receiver_fn()
    }
    export_dir, tmpdir = self._test_export_all_saved_models(
        input_receiver_fn_map)

    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        tf.saved_model.load(sess, [tf.saved_model.TRAINING], export_dir)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertTrue('name_collision' in graph_ops)
        self.assertFalse('name_collision_1' in graph_ops)
        collection_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        self.assertEqual(3, collection_vars[-1].eval())

    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        tf.saved_model.load(sess, [tf.saved_model.SERVING], export_dir)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertTrue('name_collision' in graph_ops)
        self.assertFalse('name_collision_1' in graph_ops)
        collection_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        # This is a non-obvious detail: when we load the estimator spec
        # for predict, name_collision gets set to 36. However, we then restore
        # from checkpoint, which should overwrite that var and make it the 3
        # from training. In practice, this would not be a good way to write
        # a model_fn, but leaving this check in for now to ensure consistency
        # with what would happen given our current order of spec, then
        # checkpoint.
        self.assertEqual(3, collection_vars[-1].eval())

    # Clean up.
    tf.gfile.DeleteRecursively(tmpdir)

  def _test_export_all_saved_models(self, input_receiver_fn_map):
    # Shared driver: trains _model_fn_with_x_y, exports every mode in the map,
    # validates the on-disk layout, and returns (export_dir, tmpdir) for the
    # caller to inspect and clean up.
    tmpdir = tempfile.mkdtemp()
    est = estimator.EstimatorV2(model_fn=_model_fn_with_x_y)
    est.train(input_fn=_x_y_input_fn, steps=1)

    # Perform the export.
    export_dir_base = os.path.join(
        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))
    export_dir = est.experimental_export_all_saved_models(
        export_dir_base, input_receiver_fn_map)

    # Check that all the files are in the right places.
    self.assertTrue(tf.gfile.Exists(export_dir_base))

    self._validate_exported_files(export_dir)

    return export_dir, tmpdir

  def _validate_exported_files(self, export_dir):
    # Asserts the standard SavedModel layout: saved_model.pb plus the
    # variables/ directory with index and data shards.
    self.assertTrue(tf.gfile.Exists(export_dir))
    self.assertTrue(
        tf.gfile.Exists(
            os.path.join(
                tf.compat.as_bytes(export_dir),
                tf.compat.as_bytes('saved_model.pb'))))
    self.assertTrue(
        tf.gfile.Exists(
            os.path.join(
                tf.compat.as_bytes(export_dir),
                tf.compat.as_bytes('variables'))))
    self.assertTrue(
        tf.gfile.Exists(
            os.path.join(
                tf.compat.as_bytes(export_dir),
                tf.compat.as_bytes('variables/variables.index'))))
    self.assertTrue(
        tf.gfile.Exists(
            os.path.join(
                tf.compat.as_bytes(export_dir),
                tf.compat.as_bytes('variables/variables.data-00000-of-00001'))))

  def test_export_all_saved_models_var_not_found(self):
    input_receiver_fn_map = {
        ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
        ModeKeys.EVAL: _get_supervised_input_receiver_fn(),
        ModeKeys.PREDICT: _get_serving_input_receiver_fn()
    }

    def _model_fn_with_predict_only_vars(features, labels, mode):
      _, _ = features, labels
      # 'only_in_predict' never gets into the training checkpoint, so the
      # PREDICT export cannot restore it and the export must fail.
      if mode == ModeKeys.PREDICT:
        tf.Variable(1., name='only_in_predict')
      else:
        tf.Variable(1., name='otherwise')

      prediction = tf.constant(1.)
      return model_fn_lib.EstimatorSpec(
          mode,
          predictions=prediction,
          loss=tf.constant(1.),
          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
          export_outputs={
              'test': export_lib.PredictOutput({'prediction': prediction})
          })

    tmpdir = tempfile.mkdtemp()
    est = estimator.EstimatorV2(model_fn=_model_fn_with_predict_only_vars)
    est.train(input_fn=_x_y_input_fn, steps=1)

    # Perform the export.
    export_dir_base = os.path.join(
        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))

    err_regex = r'Could not load all requested variables[\w\W]*infer'
    with self.assertRaisesRegexp(ValueError, err_regex):
      est.experimental_export_all_saved_models(export_dir_base,
                                               input_receiver_fn_map)

  def test_export_all_saved_models_metric_operation(self):
    """Ensures metrics ops.Operations can be exported (b/109740581)."""

    def _model_fn(features, labels, mode):
      del features, labels  # Unused
      metric_obj = tf_keras_v1.metrics.Mean()
      metric_obj.update_state(tf.constant([0]))
      eval_metrics = {
          'metrics1': (tf.constant([0]), tf.no_op()),
          'metrics2': metric_obj,
      }
      return model_fn_lib.EstimatorSpec(
          mode,
          predictions=tf.constant(10.),
          loss=tf.constant(1.),
          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
          eval_metric_ops=eval_metrics)

    tmpdir = tempfile.mkdtemp()
    est = estimator.EstimatorV2(model_fn=_model_fn)
    est.train(input_fn=dummy_input_fn, steps=1)

    # Perform the export.
    export_dir_base = os.path.join(
        tf.compat.as_bytes(tmpdir),
        tf.compat.as_bytes('metric_operation_export'))

    input_receiver_fn_map = {ModeKeys.EVAL: _get_supervised_input_receiver_fn()}

    export_dir = est.experimental_export_all_saved_models(
        export_dir_base, input_receiver_fn_map)

    # Restore, to validate that the export was well-formed.
    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        meta_graph = tf.saved_model.load(sess, [tag_constants.EVAL], export_dir)
        sig_outputs = meta_graph.signature_def[ModeKeys.EVAL].outputs
        self.assertTrue(sig_outputs['metrics1/update_op'].name.startswith(
            'metric_op_wrapper'))
        self.assertTrue(sig_outputs['metrics2/update_op'].name.startswith(
            'metric_op_wrapper'))

  def test_export_saved_model_with_saveables_proto_roundtrip(self):
    tmpdir = tempfile.mkdtemp()
    est = estimator.EstimatorV2(
        model_fn=_model_fn_with_saveables_for_export_tests)
    est.train(input_fn=dummy_input_fn, steps=1)
    feature_spec = {
        'x': tf.io.VarLenFeature(dtype=tf.dtypes.int64),
        'y': tf.io.VarLenFeature(dtype=tf.dtypes.int64)
    }
    serving_input_receiver_fn = (
        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))

    # Perform the export.
    export_dir_base = os.path.join(
        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))
    export_dir = est.export_saved_model(export_dir_base,
                                        serving_input_receiver_fn)

    # Check that all the files are in the right places.
    self.assertTrue(tf.gfile.Exists(export_dir_base))
    self.assertTrue(tf.gfile.Exists(export_dir))
    self.assertTrue(
        tf.gfile.Exists(
            os.path.join(
                tf.compat.as_bytes(export_dir),
                tf.compat.as_bytes('saved_model.pb'))))
    self.assertTrue(
        tf.gfile.Exists(
            os.path.join(
                tf.compat.as_bytes(export_dir),
                tf.compat.as_bytes('variables'))))
    self.assertTrue(
        tf.gfile.Exists(
            os.path.join(
                tf.compat.as_bytes(export_dir),
                tf.compat.as_bytes('variables/variables.index'))))
    self.assertTrue(
        tf.gfile.Exists(
            os.path.join(
                tf.compat.as_bytes(export_dir),
                tf.compat.as_bytes('variables/variables.data-00000-of-00001'))))

    # Restore, to validate that the export was well-formed.
    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        tf.saved_model.load(sess, [tf.saved_model.SERVING], export_dir)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertTrue('input_example_tensor' in graph_ops)
        self.assertTrue('ParseExample/ParseExampleV2' in graph_ops)
        # The original saver is used to restore variables
        self.assertTrue('save/LookupTableImportV2' in graph_ops)

    # Clean up.
    tf.gfile.DeleteRecursively(tmpdir)

  def test_export_saved_model_assets(self):
    tmpdir = tempfile.mkdtemp()
    est = estimator.EstimatorV2(model_fn=_model_fn_for_export_tests)
    est.train(input_fn=dummy_input_fn, steps=1)
    feature_spec = {
        'x': tf.io.VarLenFeature(dtype=tf.dtypes.int64),
        'y': tf.io.VarLenFeature(dtype=tf.dtypes.int64)
    }
    serving_input_receiver_fn = (
        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))

    # Create a fake asset.
    vocab_file_name = os.path.join(
        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('my_vocab_file'))
    vocab_file = tf.io.gfile.GFile(vocab_file_name, mode='w')
    vocab_file.write(_VOCAB_FILE_CONTENT)
    vocab_file.close()

    # hack in an op that uses the asset, in order to test asset export.
    # this is not actually valid, of course.
    def serving_input_receiver_with_asset_fn():
      features, receiver_tensor, _ = serving_input_receiver_fn()
      filename = ops.convert_to_tensor(
          vocab_file_name, tf.dtypes.string, name='asset_filepath')
      tf.add_to_collection(tf.compat.v1.GraphKeys.ASSET_FILEPATHS, filename)
      features['bogus_filename'] = filename

      return export_lib.ServingInputReceiver(features, receiver_tensor)

    # Perform the export.
    export_dir_base = os.path.join(
        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))
    export_dir = est.export_saved_model(export_dir_base,
                                        serving_input_receiver_with_asset_fn)

    # Check that the asset files are in the right places.
    expected_vocab_file_name = os.path.join(
        tf.compat.as_bytes(export_dir),
        tf.compat.as_bytes('assets/my_vocab_file'))
    self.assertTrue(
        tf.gfile.Exists(
            os.path.join(
                tf.compat.as_bytes(export_dir), tf.compat.as_bytes('assets'))))
    self.assertTrue(tf.gfile.Exists(expected_vocab_file_name))
    self.assertEqual(
        tf.compat.as_bytes(_VOCAB_FILE_CONTENT),
        tf.compat.as_bytes(tf.io.gfile.GFile(expected_vocab_file_name).read()))

    # Restore, to validate that the export was well-formed.
    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        tf.saved_model.load(sess, [tf.saved_model.SERVING], export_dir)
        assets = [
            x.eval() for x in graph.get_collection(tf.GraphKeys.ASSET_FILEPATHS)
        ]
        self.assertItemsEqual([vocab_file_name], assets)
        graph_ops = [x.name for x in graph.get_operations()]
        self.assertTrue('input_example_tensor' in graph_ops)
        self.assertTrue('ParseExample/ParseExampleV2' in graph_ops)
        self.assertTrue('asset_filepath' in graph_ops)
        self.assertTrue('weight' in graph_ops)

    # cleanup
    tf.gfile.DeleteRecursively(tmpdir)

  def test_export_saved_model_extra_assets(self):
    tmpdir = tempfile.mkdtemp()
    est = estimator.EstimatorV2(model_fn=_model_fn_for_export_tests)
    est.train(input_fn=dummy_input_fn, steps=1)
    feature_spec = {
        'x': tf.io.VarLenFeature(dtype=tf.dtypes.int64),
        'y': tf.io.VarLenFeature(dtype=tf.dtypes.int64)
    }
    serving_input_receiver_fn = (
        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))

    # Create a fake asset.
    extra_file_name = os.path.join(
        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('my_extra_file'))
    extra_file = tf.io.gfile.GFile(extra_file_name, mode='w')
    extra_file.write(_EXTRA_FILE_CONTENT)
    extra_file.close()

    # Perform the export.
    assets_extra = {'some/sub/directory/my_extra_file': extra_file_name}
    export_dir_base = os.path.join(
        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))
    export_dir = est.export_saved_model(
        export_dir_base, serving_input_receiver_fn, assets_extra=assets_extra)

    # Check that the asset files are in the right places.
    expected_extra_path = os.path.join(
        tf.compat.as_bytes(export_dir),
        tf.compat.as_bytes('assets.extra/some/sub/directory/my_extra_file'))
    self.assertTrue(
        tf.gfile.Exists(
            os.path.join(
                tf.compat.as_bytes(export_dir),
                tf.compat.as_bytes('assets.extra'))))
    self.assertTrue(tf.gfile.Exists(expected_extra_path))
    self.assertEqual(
        tf.compat.as_bytes(_EXTRA_FILE_CONTENT),
        tf.compat.as_bytes(tf.io.gfile.GFile(expected_extra_path).read()))

    # cleanup
    tf.gfile.DeleteRecursively(tmpdir)

  def test_export_saved_model_tensor_features(self):
    """Test that models accepting a single raw Tensor can be exported.

    See https://github.com/tensorflow/tensorflow/issues/11674

    If the model_fn and receiver_fn accept raw tensors rather than dictionaries
    as input, export_saved_model should be okay with that, too.

    """

    tmpdir = tempfile.mkdtemp()

    def _input_fn_tensor_features():
      t = tf.constant([1, 2, 3], dtype=tf.dtypes.float32, shape=[1, 3])
      return (t, None)

    def _model_fn_tensor_features(features, labels, mode):
      _ = labels
      prediction = tf.linalg.matmul(features, features, transpose_b=True)

      return model_fn_lib.EstimatorSpec(
          mode,
          predictions=prediction,
          loss=tf.constant(1.),
          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
          export_outputs={
              'test': export_lib.PredictOutput({'prediction': prediction})
          })

    def _serving_input_receiver_fn():
      feat = tf.placeholder(dtype=tf.dtypes.float32)
      return export_lib.TensorServingInputReceiver(
          features=feat, receiver_tensors=feat)

    est = estimator.EstimatorV2(model_fn=_model_fn_tensor_features)
    est.train(input_fn=_input_fn_tensor_features, steps=1)

    # Perform the export.
    export_dir_base = os.path.join(
        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))
    export_dir = est.export_saved_model(export_dir_base,
                                        _serving_input_receiver_fn)

    # Restore, to validate that the export was well-formed.
    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        tf.saved_model.load(sess, [tf.saved_model.SERVING], export_dir)
        graph_ops = [x.name.lower() for x in graph.get_operations()]
        self.assertTrue('const' in graph_ops)
        self.assertTrue('matmul' in graph_ops)

    # Clean up.
    tf.gfile.DeleteRecursively(tmpdir)

  def test_export_saved_model_int_feature_keys(self):
    """Test that the `features` dict can contain int keys."""
    tmpdir = tempfile.mkdtemp()

    def _input_fn_with_int_keys():
      features = {
          'string_key': tf.constant([1], dtype=tf.dtypes.float32),
          42: tf.constant([43], dtype=tf.dtypes.float32),
      }
      return (features, None)

    def _model_fn_with_int_keys(features, labels, mode):
      _ = labels
      prediction = tf.math.maximum(features['string_key'], features[42])

      return model_fn_lib.EstimatorSpec(
          mode,
          predictions=prediction,
          loss=tf.constant(1.),
          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
          export_outputs={
              'test': export_lib.PredictOutput({'prediction': prediction})
          })

    def _serving_input_receiver_fn():
      features = {
          'string_key': tf.placeholder(dtype=tf.dtypes.float32),
          42: tf.placeholder(dtype=tf.dtypes.float32, name='42_placeholder'),
      }
      # int is only allowed in the `features` dict, not the `receiver_tensors`.
      receiver_tensors = {
          'string_key': features['string_key'],
          '42_key': features[42],
      }
      return export_lib.ServingInputReceiver(
          features=features, receiver_tensors=receiver_tensors)

    est = estimator.EstimatorV2(model_fn=_model_fn_with_int_keys)
    est.train(input_fn=_input_fn_with_int_keys, steps=1)

    # Perform the export.
    export_dir_base = os.path.join(
        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))
    export_dir = est.export_saved_model(export_dir_base,
                                        _serving_input_receiver_fn)

    # Restore, to validate that the export was well-formed.
    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        meta_graph_def = tf.saved_model.load(sess, [tf.saved_model.SERVING],
                                             export_dir)
        graph_ops = [x.name.lower() for x in graph.get_operations()]
        self.assertTrue('maximum' in graph_ops)
        self.assertTrue('42_placeholder' in graph_ops)
        self.assertTrue(
            '42_key' in meta_graph_def.signature_def['serving_default'].inputs)

    # Clean up.
    tf.gfile.DeleteRecursively(tmpdir)

  def test_scaffold_is_used_for_saver(self):
    tmpdir = tempfile.mkdtemp()

    def _model_fn_scaffold(features, labels, mode):
      _, _ = features, labels
      tf.Variable(1., name='weight')
      self.mock_saver = get_mock_saver()
      scores = tf.constant([3.])
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          predictions=tf.constant([[1.]]),
          loss=tf.constant(0.),
          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
          scaffold=tf.train.Scaffold(saver=self.mock_saver),
          export_outputs={'test': export_lib.ClassificationOutput(scores)})

    est = estimator.EstimatorV2(model_fn=_model_fn_scaffold)
    est.train(dummy_input_fn, steps=1)
    feature_spec = {
        'x': tf.io.VarLenFeature(dtype=tf.dtypes.int64),
        'y': tf.io.VarLenFeature(dtype=tf.dtypes.int64)
    }
    serving_input_receiver_fn = (
        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))

    # Perform the export.
    export_dir_base = os.path.join(
        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))
    est.export_saved_model(export_dir_base, serving_input_receiver_fn)

    self.assertTrue(self.mock_saver.restore.called)
    self.assertTrue(self.mock_saver.export_meta_graph.called)
    self.assertTrue(self.mock_saver.save.called)

  def test_scaffold_is_used_for_saver_multiple_modes(self):
    tmpdir = tempfile.mkdtemp()
    savers = {'predict_saver': None, 'train_saver': None}

    def _model_fn_scaffold(features, labels, mode):
      _, _ = features, labels
      tf.Variable(1., name='weight')

      scores = tf.constant([3.])
      if mode == ModeKeys.PREDICT:
        savers['predict_saver'] = get_mock_saver()
        scaffold = tf.train.Scaffold(saver=savers['predict_saver'])
      elif mode == ModeKeys.TRAIN:
        savers['train_saver'] = get_mock_saver()
        scaffold = tf.train.Scaffold(saver=savers['train_saver'])
      else:
        scaffold = tf.train.Scaffold()
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          predictions=tf.constant([[1.]]),
          loss=tf.constant(0.),
          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
          scaffold=scaffold,
          export_outputs={'test': export_lib.ClassificationOutput(scores)})

    est = estimator.EstimatorV2(model_fn=_model_fn_scaffold)
    est.train(dummy_input_fn, steps=1)
    input_receiver_fn_map = {
        ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
        ModeKeys.EVAL: _get_supervised_input_receiver_fn(),
        ModeKeys.PREDICT: _get_serving_input_receiver_fn()
    }

    # Perform the export.
    export_dir_base = os.path.join(
        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))
    est.experimental_export_all_saved_models(export_dir_base,
                                             input_receiver_fn_map)

    self.assertTrue(savers['train_saver'].restore.called)
    self.assertEqual(savers['train_saver'].export_meta_graph.call_count, 1)
    self.assertEqual(savers['train_saver'].save.call_count, 1)

    self.assertTrue(savers['predict_saver'].restore.called)
    self.assertEqual(savers['predict_saver'].export_meta_graph.call_count, 1)
    self.assertEqual(savers['predict_saver'].save.call_count, 0)

  def test_scaffold_is_used_for_local_init(self):
    tmpdir = tempfile.mkdtemp()

    def _model_fn_scaffold(features, labels, mode):
      _, _ = features, labels
      my_int = tf.Variable(
          1, name='my_int', collections=[tf.GraphKeys.LOCAL_VARIABLES])
      _ = training.get_or_create_steps_per_run_variable()
      scores = tf.constant([3.])
      with tf.control_dependencies([
          tf.initializers.local_variables(),
          tf.initializers.tables_initializer()
      ]):
        assign_op = tf.assign(my_int, 12345)

      # local_init_op must be an Operation, not a Tensor.
      custom_local_init_op = tf.group(assign_op)
      return model_fn_lib.EstimatorSpec(
          mode=mode,
          predictions=tf.constant([[1.]]),
          loss=tf.constant(0.),
          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
          scaffold=tf.train.Scaffold(local_init_op=custom_local_init_op),
          export_outputs={'test': export_lib.ClassificationOutput(scores)})

    est = estimator.EstimatorV2(model_fn=_model_fn_scaffold)
    est.train(dummy_input_fn, steps=1)
    feature_spec = {
        'x': tf.io.VarLenFeature(dtype=tf.dtypes.int64),
        'y': tf.io.VarLenFeature(dtype=tf.dtypes.int64)
    }
    serving_input_receiver_fn = (
        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))

    # Perform the export.
    export_dir_base = os.path.join(
        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))
    export_dir = est.export_saved_model(export_dir_base,
                                        serving_input_receiver_fn)

    # Restore, to validate that the custom local_init_op runs.
    with tf.Graph().as_default() as graph:
      with tf.Session(graph=graph) as sess:
        tf.saved_model.load(sess, [tf.saved_model.SERVING], export_dir)
        my_int = graph.get_tensor_by_name('my_int:0')
        my_int_value = sess.run(my_int)
        self.assertEqual(12345, my_int_value)

  def test_scaffold_is_used_for_local_init_multiple_modes(self):
    tmpdir = tempfile.mkdtemp()

    def _model_fn_scaffold(features, labels, mode):
      _, _ = features, labels
      my_int = tf.Variable(
          1, name='my_int', collections=[tf.GraphKeys.LOCAL_VARIABLES])
      scores = tf.constant([3.])
      with tf.control_dependencies([
          tf.initializers.local_variables(),
          tf.initializers.tables_initializer()
      ]):
        assign_op = tf.assign(my_int, 12345)

      custom_local_init_op = None
      if mode == ModeKeys.PREDICT:
        # local_init_op must be an Operation, not a Tensor.
        custom_local_init_op = tf.group(assign_op)

      return model_fn_lib.EstimatorSpec(
          mode=mode,
          predictions=tf.constant([[1.]]),
          loss=tf.constant(0.),
          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),
          scaffold=tf.train.Scaffold(local_init_op=custom_local_init_op),
          export_outputs={'test': export_lib.ClassificationOutput(scores)})

    est = estimator.EstimatorV2(model_fn=_model_fn_scaffold)
    est.train(dummy_input_fn, steps=1)
    input_receiver_fn_map = {
        ModeKeys.TRAIN: _get_supervised_input_receiver_fn(),
        ModeKeys.EVAL: _get_supervised_input_receiver_fn(),
        ModeKeys.PREDICT: _get_serving_input_receiver_fn()
    }

    # Perform the export.
    export_dir_base = os.path.join(
        tf.compat.as_bytes(tmpdir),
tf.compat.as_bytes('export'))\n    export_dir = est.experimental_export_all_saved_models(\n        export_dir_base, input_receiver_fn_map)\n\n    # Restore, to validate that the custom local_init_op runs.\n    with tf.Graph().as_default() as graph:\n      with tf.Session(graph=graph) as sess:\n        tf.saved_model.load(sess, [tf.saved_model.SERVING], export_dir)\n        my_int = graph.get_tensor_by_name('my_int:0')\n        my_int_value = sess.run(my_int)\n        self.assertEqual(12345, my_int_value)\n    with tf.Graph().as_default() as graph:\n      with tf.Session(graph=graph) as sess:\n        tf.saved_model.load(sess, [tf.saved_model.TRAINING], export_dir)\n        my_int = graph.get_tensor_by_name('my_int:0')\n        my_int_value = sess.run(my_int)\n        self.assertEqual(1, my_int_value)\n\n  def test_features_labels_mode(self):\n    given_features = {'test-features': tf.constant([[1], [1]])}\n\n    def serving_input_receiver_fn():\n      return export_lib.ServingInputReceiver(\n          given_features, tf.placeholder(dtype=tf.dtypes.string))\n\n    def _model_fn(features, labels, mode):\n      self.features, self.labels, self.mode = features, labels, mode\n      return model_fn_lib.EstimatorSpec(\n          mode=mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions=tf.constant([[0.]]),\n          export_outputs={\n              'test': export_lib.ClassificationOutput(tf.constant([[0.]]))\n          })\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n    est.export_saved_model(tempfile.mkdtemp(), serving_input_receiver_fn)\n    self.assertEqual(given_features, self.features)\n    self.assertIsNone(self.labels)\n    self.assertEqual(ModeKeys.PREDICT, self.mode)\n\n  def test_graph_initialization_global_step_and_random_seed(self):\n    expected_random_seed = run_config.RunConfig().tf_random_seed\n\n    def 
_model_fn(features, labels, mode):\n      _, _, _ = features, labels, mode\n      self.assertIsNotNone(tf.train.get_global_step())\n      self.assertEqual(expected_random_seed, tf.get_default_graph().seed)\n      return model_fn_lib.EstimatorSpec(\n          mode=mode,\n          loss=tf.constant(0.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1),\n          predictions=tf.constant([[0.]]),\n          export_outputs={\n              'test': export_lib.ClassificationOutput(tf.constant([[0.]]))\n          })\n\n    def serving_input_receiver_fn():\n      return export_lib.ServingInputReceiver(\n          {'test-features': tf.constant([[1], [1]])},\n          tf.placeholder(dtype=tf.dtypes.string))\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(dummy_input_fn, steps=1)\n    est.export_saved_model(tempfile.mkdtemp(), serving_input_receiver_fn)\n\n  def test_export_saved_model_respects_soft_placement(self):\n\n    def model_fn_with_a_gpu_op_but_no_kernel(features, labels, mode):\n      _, _ = features, labels\n      table = saver_test_utils.CheckpointedOp(name='v2')\n\n      update_global_step = tf.assign_add(tf.train.get_global_step(), 1)\n      with tf.control_dependencies([update_global_step]):\n        train_op = table.insert('k1', 30.0)\n\n      #  In this test, there are no GPUs available.  
The goal is to verify that\n      #  export_saved_model executes nevertheless.\n      with tf.device('/gpu:0'):\n        string_op = tf.strings.as_string(update_global_step)\n\n      with tf.control_dependencies([string_op]):\n        prediction = table.lookup('k1', 0.0)\n\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          predictions=prediction,\n          loss=tf.constant(1.),\n          train_op=train_op,\n          export_outputs={\n              'test': export_lib.PredictOutput({'prediction': prediction})\n          })\n\n    tmpdir = tempfile.mkdtemp()\n    est = estimator.EstimatorV2(model_fn=model_fn_with_a_gpu_op_but_no_kernel)\n    est.train(input_fn=dummy_input_fn, steps=1)\n    feature_spec = {\n        'x': tf.io.VarLenFeature(dtype=tf.dtypes.int64),\n        'y': tf.io.VarLenFeature(dtype=tf.dtypes.int64)\n    }\n    serving_input_receiver_fn = (\n        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))\n    export_dir_base = os.path.join(\n        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))\n\n    export_dir = est.export_saved_model(export_dir_base,\n                                        serving_input_receiver_fn)\n\n    # At this point, if export_saved_model executed with\n    # allow_soft_placement=True, then the GPU-assigned operation was silently\n    # placed on the CPU.  
Otherwise, an exception would have been raised\n    # related to the fact that the requested GPU device isn't available.\n\n    # Expectations below assume that export_saved_model has completed normally.\n    self.assertTrue(tf.gfile.Exists(export_dir_base))\n    self.assertTrue(tf.gfile.Exists(export_dir))\n    self.assertTrue(\n        tf.gfile.Exists(\n            os.path.join(\n                tf.compat.as_bytes(export_dir),\n                tf.compat.as_bytes('saved_model.pb'))))\n    self.assertTrue(\n        tf.gfile.Exists(\n            os.path.join(\n                tf.compat.as_bytes(export_dir),\n                tf.compat.as_bytes('variables'))))\n    self.assertTrue(\n        tf.gfile.Exists(\n            os.path.join(\n                tf.compat.as_bytes(export_dir),\n                tf.compat.as_bytes('variables/variables.index'))))\n    self.assertTrue(\n        tf.gfile.Exists(\n            os.path.join(\n                tf.compat.as_bytes(export_dir),\n                tf.compat.as_bytes('variables/variables.data-00000-of-00001'))))\n\n    tf.gfile.DeleteRecursively(tmpdir)\n\n  def _validate_strip_default_attrs(self, estimator_cls, export_fn,\n                                    attributes_stripped):\n    \"\"\"Validate estimator export correctly strips/leaves default attributes.\n\n    Args:\n      estimator_cls: `Estimator` or `EstimatorV2`\n      export_fn: a function that takes in an estimator and export arguments, and\n        exports the estimator.\n      attributes_stripped: whether to attributes are expected to be stripped in\n        the MetaGraphDef.\n    \"\"\"\n    est = estimator_cls(model_fn=_model_fn_for_export_tests)\n    est.train(input_fn=dummy_input_fn, steps=1)\n    feature_spec = {\n        'x': tf.io.VarLenFeature(dtype=tf.dtypes.int64),\n        'y': tf.io.VarLenFeature(dtype=tf.dtypes.int64)\n    }\n    serving_input_receiver_fn = (\n        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))\n\n    # Perform 
the export, and obtain the MetaGraphDefs\n    tmpdir = tempfile.mkdtemp()\n    export_dir_base = os.path.join(\n        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('export'))\n\n    export_dir = export_fn(est, export_dir_base, serving_input_receiver_fn)\n    saved_model_pb = loader_impl.parse_saved_model(export_dir)\n    self.assertIsNotNone(saved_model_pb)\n    meta_graph_def = [\n        x for x in saved_model_pb.meta_graphs\n        if x.meta_info_def.tags == [tf.saved_model.SERVING]\n    ][0]\n\n    # \"weight\" node in graph is a \"Variable\" Op with 2 default valued attrs.\n    #   o \"container\"    : \"\".\n    #   o \"shared_name\"  : \"\".\n\n    # When default attributes are not stripped, the \"weight\" node should have\n    # attributes \"container\" and \"shared_name\". When default attributes are\n    # stripped, the node should not have these attributes.\n    node_def = test_util.get_node_def_from_graph('weight',\n                                                 meta_graph_def.graph_def)\n    self.assertEqual(attributes_stripped, 'container' not in node_def.attr)\n    self.assertEqual(attributes_stripped, 'shared_name' not in node_def.attr)\n\n    # Clean up.\n    tf.gfile.DeleteRecursively(tmpdir)\n\n  def test_export_saved_model_proto_strip_default_attrs(self):\n    # Test deprecated export_savedmodel to ensure that V1 behavior is consistent\n    self._validate_strip_default_attrs(\n        estimator.Estimator,\n        lambda e, *args: e.export_savedmodel(*args, strip_default_attrs=True),\n        True)\n    self._validate_strip_default_attrs(\n        estimator.Estimator,\n        lambda e, *args: e.export_savedmodel(*args, strip_default_attrs=False),\n        False)\n\n    # Make sure that export_saved_model strips the default attributes.\n    self._validate_strip_default_attrs(\n        estimator.Estimator, lambda e, *args: e.export_saved_model(*args), True)\n    self._validate_strip_default_attrs(\n        estimator.EstimatorV2, lambda e, 
*args: e.export_saved_model(*args),\n        True)\n\n  def test_export_saved_model_no_export_outputs(self):\n    \"\"\"Ensure that an EstimatorSpec without outputs defined can be exported.\"\"\"\n\n    def _model_fn(features, labels, mode):\n      _, _ = features, labels\n      tf.Variable(1., name='weight')\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          predictions=tf.constant(10.),\n          loss=tf.constant(1.),\n          train_op=tf.assign_add(tf.compat.v1.train.get_global_step(), 1))\n\n    tmpdir = tempfile.mkdtemp()\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    est.train(input_fn=dummy_input_fn, steps=1)\n\n    # Perform the export.\n    export_dir_base = os.path.join(\n        tf.compat.as_bytes(tmpdir), tf.compat.as_bytes('no_export_outputs'))\n    export_dir = est.export_saved_model(export_dir_base,\n                                        _get_serving_input_receiver_fn())\n\n    # Check that all the files are in the right places.\n    self.assertTrue(tf.gfile.Exists(export_dir_base))\n    self._validate_exported_files(export_dir)\n\n    # Restore, to validate that the export was well-formed.\n    with tf.Graph().as_default() as graph:\n      with tf.Session(graph=graph) as sess:\n        meta_graph = tf.saved_model.load(sess, [tf.saved_model.SERVING],\n                                         export_dir)\n        graph_ops = [x.name for x in graph.get_operations()]\n        self.assertTrue('weight' in graph_ops)\n\n        sig_def = meta_graph.signature_def\n        self.assertEqual(len(sig_def), 1)\n        sig_outputs = sig_def[\n            tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs\n        self.assertEqual(sig_outputs['output'].name, 'Const:0')\n\n  def test_export_from_warm_start(self):\n\n    def _make_model_fn(x):\n\n      def _variable_creating_model_fn(features, labels, mode):\n        _, _ = features, labels\n        tf.get_variable('x', initializer=x)\n        global_step = 
tf.train.get_global_step()\n        return model_fn_lib.EstimatorSpec(\n            mode,\n            predictions=tf.constant(1.),\n            loss=tf.constant(1.),\n            train_op=tf.assign_add(global_step, 1))\n\n      return _variable_creating_model_fn\n\n    est = estimator.EstimatorV2(model_fn=_make_model_fn(42.))\n    est.train(dummy_input_fn, steps=10)\n\n    warm_started_est = estimator.EstimatorV2(\n        model_fn=_make_model_fn(36.), warm_start_from=est.model_dir)\n    saved_model_dir = warm_started_est.export_saved_model(\n        tempfile.mkdtemp(), _get_serving_input_receiver_fn())\n    variable_dir = path_helpers.get_variables_path(saved_model_dir)\n    self.assertEqual(42., tf.train.load_variable(variable_dir, 'x'))\n\n  def test_export_saved_model_symbol_deprecated(self):\n    est = estimator.EstimatorV2(model_fn=_model_fn_for_export_tests)\n    with self.assertRaisesRegexp(AttributeError,\n                                 'Please use `export_saved_model`'):\n      est.export_savedmodel\n\n\nclass EstimatorHookOrderingTest(tf.test.TestCase):\n\n  def testCustomHooksAreCalledBeforeNanTensorHook(self):\n\n    def nan_making_model_fn(mode, features, labels):\n      \"\"\"A graph that generates NaN's for testing.\"\"\"\n      del features, labels\n\n      global_step = tf.Variable(0, dtype=tf.dtypes.int64, name='global_step')\n      inc_global_step = tf.assign_add(global_step, 1)\n      nan_const = tf.constant(np.nan, dtype=tf.dtypes.float32)\n      loss = tf.cond(inc_global_step > 1, lambda: nan_const, lambda: 1.0)\n\n      return model_fn_lib.EstimatorSpec(\n          mode=mode,\n          predictions=global_step.read_value(),\n          loss=loss,\n          train_op=inc_global_step)\n\n    def empty_input_fn():\n      return dict(), None\n\n    class AfterRunCountingHook(tf.train.SessionRunHook):\n      \"\"\"Hooks that counts the number of times after_run() is called.\"\"\"\n\n      def __init__(self):\n        self.after_run_count = 
0\n\n      def after_run(self, run_context, run_values):\n        del run_context, run_values\n        self.after_run_count += 1\n\n    test_hook = AfterRunCountingHook()\n    est = estimator.EstimatorV2(model_fn=nan_making_model_fn)\n    with self.assertRaises(tf.train.NanLossDuringTrainingError):\n      est.train(input_fn=empty_input_fn, steps=2, hooks=[test_hook])\n    self.assertEqual(2, test_hook.after_run_count)\n\n\nclass EstimatorIntegrationTest(tf.test.TestCase):\n\n  def test_complete_flow_with_a_simple_linear_model(self):\n\n    def _model_fn(features, labels, mode):\n      predictions = tf_keras_v1.__internal__.legacy.layers.dense(\n          features['x'], 1, kernel_initializer=tf.initializers.zeros())\n      export_outputs = {'predictions': export_lib.RegressionOutput(predictions)}\n\n      if mode == ModeKeys.PREDICT:\n        return model_fn_lib.EstimatorSpec(\n            mode, predictions=predictions, export_outputs=export_outputs)\n\n      loss = tf_keras_v1.losses.MeanSquaredError()(labels, predictions)\n      train_op = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(\n          loss, tf.train.get_global_step())\n      mean = tf_keras_v1.metrics.Mean()\n      mean.update_state(loss)\n      eval_metric_ops = {\n          'absolute_error': tf.metrics.mean_absolute_error(labels, predictions),\n          'mean': mean,\n      }\n\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          predictions=predictions,\n          loss=loss,\n          train_op=train_op,\n          eval_metric_ops=eval_metric_ops,\n          export_outputs=export_outputs)\n\n    est = estimator.EstimatorV2(model_fn=_model_fn)\n    data = np.linspace(0., 1., 100, dtype=np.float32).reshape(-1, 1)\n\n    # TRAIN\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, batch_size=50, num_epochs=None, shuffle=True)\n    est.train(train_input_fn, steps=200)\n\n    # EVALUATE\n    eval_input_fn = 
numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, batch_size=50, num_epochs=1, shuffle=True)\n    scores = est.evaluate(eval_input_fn)\n    self.assertEqual(200, scores['global_step'])\n    self.assertGreater(0.1, scores['absolute_error'])\n    self.assertAlmostEqual(4.4e-14, scores['mean'], places=2)\n\n    # PREDICT\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=None, batch_size=10, num_epochs=1, shuffle=False)\n    predictions = list(est.predict(predict_input_fn))\n    self.assertAllClose(data, predictions, atol=0.01)\n\n    # EXPORT\n    feature_spec = {'x': tf.io.FixedLenFeature([1], tf.dtypes.float32)}\n    serving_input_receiver_fn = (\n        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))\n    export_dir = est.export_saved_model(tempfile.mkdtemp(),\n                                        serving_input_receiver_fn)\n    self.assertTrue(tf.gfile.Exists(export_dir))\n\n\nclass EstimatorInputContextTest(tf.test.TestCase):\n\n  def test_with_input_fn(self):\n    total_batch_size = 10\n    num_shards = 2\n\n    def _input_with_context(input_context):\n      batch_size = total_batch_size // num_shards\n      self.assertEqual('DummyInputContext', input_context.name)\n      self.assertEqual(batch_size, input_context.batch_size)\n      return tf.data.Dataset.from_tensors(([1.], [2.]))\n\n    def _input_without_context():\n      return tf.data.Dataset.from_tensors(([1.], [2.]))\n\n    class DummyInputContext(object):\n\n      def __init__(self, n_shards, total_bs):\n        self._name = 'DummyInputContext'\n        self._num_shards = n_shards\n        self._total_batch_size = total_bs\n\n      @property\n      def name(self):\n        return self._name\n\n      @property\n      def batch_size(self):\n        return self._total_batch_size // self._num_shards\n\n    # This class is the mock for DistributionStrategy. 
It only overrides\n    # the make_input_fn_iterator method.\n    class DummyDistributionStrategy(object):\n\n      def __init__(self, n_shards):\n        self._num_shards = n_shards\n\n      def make_input_fn_iterator(self, input_fn):\n        input_context = DummyInputContext(num_shards, total_batch_size)\n        return input_fn(input_context)\n\n    distribution = DummyDistributionStrategy(num_shards)\n    est = estimator.EstimatorV2(model_fn=dummy_model_fn)\n    # We only test the `input_fn` instead of calling `Estimator.train`\n    est._get_iterator_from_input_fn(_input_with_context, None, distribution)  # pylint: disable=protected-access\n    est._get_iterator_from_input_fn(_input_without_context, None, distribution)  # pylint: disable=protected-access\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/export/__init__.py",
    "content": ""
  },
  {
    "path": "tensorflow_estimator/python/estimator/export/export.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Configuration and utilities for receiving inputs at serving time.\n\nExtends the export utils defined in core TensorFlow.\n\nPlease avoid importing this file directly, all of the public functions have\nbeen exported to export_lib.py.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.saved_model.model_utils import export_utils\nfrom tensorflow.python.saved_model.model_utils.export_utils import SINGLE_FEATURE_DEFAULT_NAME\nfrom tensorflow.python.saved_model.model_utils.export_utils import SINGLE_LABEL_DEFAULT_NAME\nfrom tensorflow.python.saved_model.model_utils.export_utils import SINGLE_RECEIVER_DEFAULT_NAME\nfrom tensorflow_estimator.python.estimator import util\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\n\n_SINGLE_TENSOR_DEFAULT_NAMES = {\n    'feature': SINGLE_FEATURE_DEFAULT_NAME,\n    'label': SINGLE_LABEL_DEFAULT_NAME,\n    'receiver_tensor': SINGLE_RECEIVER_DEFAULT_NAME,\n    'receiver_tensors_alternative': SINGLE_RECEIVER_DEFAULT_NAME\n}\n\n\ndef wrap_and_check_input_tensors(tensors, field_name, 
allow_int_keys=False):\n  \"\"\"Ensure that tensors is a dict of str to Tensor mappings.\n\n  Args:\n    tensors: dict of `str` (or `int`s if `allow_int_keys=True`) to `Tensors`, or\n      a single `Tensor`.\n    field_name: name of the member field of `ServingInputReceiver` whose value\n      is being passed to `tensors`.\n    allow_int_keys: If set to true, the `tensor` dict keys may also be `int`s.\n\n  Returns:\n    dict of str to Tensors; this is the original dict if one was passed, or\n    the original tensor wrapped in a dictionary.\n\n  Raises:\n    ValueError: if tensors is None, or has non-string keys,\n      or non-Tensor values\n  \"\"\"\n  if tensors is None:\n    raise ValueError('{}s must be defined.'.format(field_name))\n  if not isinstance(tensors, dict):\n    tensors = {_SINGLE_TENSOR_DEFAULT_NAMES[field_name]: tensors}\n  for name, tensor in tensors.items():\n    _check_tensor_key(name, error_label=field_name, allow_ints=allow_int_keys)\n    _check_tensor(tensor, name, error_label=field_name)\n  return tensors\n\n\ndef _check_tensor(tensor, name, error_label='feature'):\n  \"\"\"Check that passed `tensor` is a Tensor or SparseTensor or RaggedTensor.\"\"\"\n  if not (isinstance(tensor, tf.Tensor) or\n          isinstance(tensor, tf.sparse.SparseTensor) or\n          isinstance(tensor, tf.RaggedTensor)):\n    fmt_name = ' {}'.format(name) if name else ''\n    value_error = ValueError('{}{} must be a Tensor, SparseTensor, or '\n                             'RaggedTensor.'.format(error_label, fmt_name))\n    # NOTE(ericmc): This if-else block is a specific carve-out for\n    # LabeledTensor, which has a `.tensor` attribute and which is\n    # convertible to tf.Tensor via ops.convert_to_tensor.\n    # Allowing all types convertible to tf.Tensor is considered by soergel@\n    # to be too permissive.\n    # TODO(soergel): accept any type convertible to Tensor,\n    # as in cl/193238295 snapshot #6.\n    if hasattr(tensor, 'tensor'):\n      try:\n        
ops.convert_to_tensor(tensor)\n      except TypeError:\n        raise value_error\n    else:\n      raise value_error\n\n\ndef _check_tensor_key(name, error_label='feature', allow_ints=False):\n  if not isinstance(name, six.string_types):\n    if not allow_ints:\n      raise ValueError('{} keys must be strings: {}.'.format(error_label, name))\n    elif not isinstance(name, six.integer_types):\n      raise ValueError('{} keys must be strings or ints: {}.'.format(\n          error_label, name))\n\n\n@estimator_export('estimator.export.ServingInputReceiver')\nclass ServingInputReceiver(\n    collections.namedtuple(\n        'ServingInputReceiver',\n        ['features', 'receiver_tensors', 'receiver_tensors_alternatives'])):\n  \"\"\"A return type for a serving_input_receiver_fn.\n\n  Attributes:\n    features: A `Tensor`, `SparseTensor`, or dict of string or int to `Tensor`\n      or `SparseTensor`, specifying the features to be passed to the model.\n      Note: if `features` passed is not a dict, it will be wrapped in a dict\n        with a single entry, using 'feature' as the key.  Consequently, the\n        model\n      must accept a feature dict of the form {'feature': tensor}.  You may use\n        `TensorServingInputReceiver` if you want the tensor to be passed as is.\n    receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor`\n      or `SparseTensor`, specifying input nodes where this receiver expects to\n      be fed by default.  Typically, this is a single placeholder expecting\n      serialized `tf.Example` protos.\n    receiver_tensors_alternatives: a dict of string to additional groups of\n      receiver tensors, each of which may be a `Tensor`, `SparseTensor`, or dict\n      of string to `Tensor` or`SparseTensor`. These named receiver tensor\n      alternatives generate additional serving signatures, which may be used to\n      feed inputs at different points within the input receiver subgraph.  
A\n      typical usage is to allow feeding raw feature `Tensor`s *downstream* of\n      the tf.parse_example() op. Defaults to None.\n  \"\"\"\n\n  def __new__(cls,\n              features,\n              receiver_tensors,\n              receiver_tensors_alternatives=None):\n    features = wrap_and_check_input_tensors(\n        features, 'feature', allow_int_keys=True)\n\n    receiver_tensors = wrap_and_check_input_tensors(receiver_tensors,\n                                                    'receiver_tensor')\n\n    if receiver_tensors_alternatives is not None:\n      if not isinstance(receiver_tensors_alternatives, dict):\n        raise ValueError(\n            'receiver_tensors_alternatives must be a dict: {}.'.format(\n                receiver_tensors_alternatives))\n      for alternative_name, receiver_tensors_alt in (\n          six.iteritems(receiver_tensors_alternatives)):\n        # Updating dict during iteration is OK in this case.\n        receiver_tensors_alternatives[alternative_name] = (\n            wrap_and_check_input_tensors(receiver_tensors_alt,\n                                         'receiver_tensors_alternative'))\n\n    return super(ServingInputReceiver, cls).__new__(\n        cls,\n        features=features,\n        receiver_tensors=receiver_tensors,\n        receiver_tensors_alternatives=receiver_tensors_alternatives)\n\n\n@estimator_export('estimator.export.TensorServingInputReceiver')\nclass TensorServingInputReceiver(\n    collections.namedtuple(\n        'TensorServingInputReceiver',\n        ['features', 'receiver_tensors', 'receiver_tensors_alternatives'])):\n  \"\"\"A return type for a serving_input_receiver_fn.\n\n  This is for use with models that expect a single `Tensor` or `SparseTensor`\n  as an input feature, as opposed to a dict of features.\n\n  The normal `ServingInputReceiver` always returns a feature dict, even if it\n  contains only one entry, and so can be used only with models that accept such\n  a dict.  
For models that accept only a single raw feature, the\n  `serving_input_receiver_fn` provided to `Estimator.export_saved_model()`\n  should return this `TensorServingInputReceiver` instead.  See:\n  https://github.com/tensorflow/tensorflow/issues/11674\n\n  Note that the receiver_tensors and receiver_tensor_alternatives arguments\n  will be automatically converted to the dict representation in either case,\n  because the SavedModel format requires each input `Tensor` to have a name\n  (provided by the dict key).\n\n  Attributes:\n    features: A single `Tensor` or `SparseTensor`, representing the feature to\n      be passed to the model.\n    receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor`\n      or `SparseTensor`, specifying input nodes where this receiver expects to\n      be fed by default.  Typically, this is a single placeholder expecting\n      serialized `tf.Example` protos.\n    receiver_tensors_alternatives: a dict of string to additional groups of\n      receiver tensors, each of which may be a `Tensor`, `SparseTensor`, or dict\n      of string to `Tensor` or`SparseTensor`. These named receiver tensor\n      alternatives generate additional serving signatures, which may be used to\n      feed inputs at different points within the input receiver subgraph.  A\n      typical usage is to allow feeding raw feature `Tensor`s *downstream* of\n      the tf.parse_example() op. 
Defaults to None.\n  \"\"\"\n\n  def __new__(cls,\n              features,\n              receiver_tensors,\n              receiver_tensors_alternatives=None):\n    if features is None:\n      raise ValueError('features must be defined.')\n    _check_tensor(features, None)\n\n    receiver = ServingInputReceiver(\n        features=features,\n        receiver_tensors=receiver_tensors,\n        receiver_tensors_alternatives=receiver_tensors_alternatives)\n\n    return super(TensorServingInputReceiver, cls).__new__(\n        cls,\n        features=receiver.features[SINGLE_FEATURE_DEFAULT_NAME],\n        receiver_tensors=receiver.receiver_tensors,\n        receiver_tensors_alternatives=receiver.receiver_tensors_alternatives)\n\n\nclass UnsupervisedInputReceiver(ServingInputReceiver):\n  \"\"\"A return type for a training_input_receiver_fn or eval_input_receiver_fn.\n\n  This differs from SupervisedInputReceiver in that it does not require a set\n  of labels.\n\n  Attributes:\n    features: A `Tensor`, `SparseTensor`, or dict of string to `Tensor` or\n      `SparseTensor`, specifying the features to be passed to the model.\n    receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor`\n      or `SparseTensor`, specifying input nodes where this receiver expects to\n      be fed by default.  
Typically, this is a single placeholder expecting\n      serialized `tf.Example` protos.\n  \"\"\"\n\n  def __new__(cls, features, receiver_tensors):\n    return super(UnsupervisedInputReceiver, cls).__new__(\n        cls,\n        features=features,\n        receiver_tensors=receiver_tensors,\n        receiver_tensors_alternatives=None)\n\n\nclass SupervisedInputReceiver(\n    collections.namedtuple('SupervisedInputReceiver',\n                           ['features', 'labels', 'receiver_tensors'])):\n  \"\"\"A return type for a training_input_receiver_fn or eval_input_receiver_fn.\n\n  This differs from a ServingInputReceiver in that (1) this receiver expects\n  a set of labels to be passed in with features, and (2) this receiver does\n  not support receiver_tensors_alternatives, which are primarily used for\n  serving.\n\n  The expected return values are:\n    features: A `Tensor`, `SparseTensor`, or dict of string or int to `Tensor`\n      or `SparseTensor`, specifying the features to be passed to the model.\n    labels: A `Tensor`, `SparseTensor`, or dict of string or int to `Tensor` or\n      `SparseTensor`, specifying the labels to be passed to the model.\n    receiver_tensors: A `Tensor`, `SparseTensor`, or dict of string to `Tensor`\n      or `SparseTensor`, specifying input nodes where this receiver expects to\n      be fed by default.  
Typically, this is a single placeholder expecting\n      serialized `tf.Example` protos.\n\n  \"\"\"\n\n  def __new__(cls, features, labels, receiver_tensors):\n    # Both features and labels can be dicts or raw tensors.\n    # wrap_and_check_input_tensors is called here only to validate the tensors.\n    # The wrapped dict that is returned is deliberately discarded.\n    wrap_and_check_input_tensors(features, 'feature', allow_int_keys=True)\n    wrap_and_check_input_tensors(labels, 'label', allow_int_keys=True)\n\n    receiver_tensors = wrap_and_check_input_tensors(receiver_tensors,\n                                                    'receiver_tensor')\n\n    return super(SupervisedInputReceiver, cls).__new__(\n        cls,\n        features=features,\n        labels=labels,\n        receiver_tensors=receiver_tensors)\n\n\n@estimator_export('estimator.export.build_parsing_serving_input_receiver_fn')\ndef build_parsing_serving_input_receiver_fn(feature_spec,\n                                            default_batch_size=None):\n  \"\"\"Build a serving_input_receiver_fn expecting fed tf.Examples.\n\n  Creates a serving_input_receiver_fn that expects a serialized tf.Example fed\n  into a string placeholder.  The function parses the tf.Example according to\n  the provided feature_spec, and returns all parsed Tensors as features.\n\n  Args:\n    feature_spec: a dict of string to `VarLenFeature`/`FixedLenFeature`.\n    default_batch_size: the number of query examples expected per batch. 
Leave\n      unset for variable batch size (recommended).\n\n  Returns:\n    A serving_input_receiver_fn suitable for use in serving.\n  \"\"\"\n\n  def serving_input_receiver_fn():\n    \"\"\"An input_fn that expects a serialized tf.Example.\"\"\"\n    serialized_tf_example = tf.compat.v1.placeholder(\n        dtype=tf.dtypes.string,\n        shape=[default_batch_size],\n        name='input_example_tensor')\n    receiver_tensors = {'examples': serialized_tf_example}\n    features = tf.compat.v1.io.parse_example(serialized_tf_example,\n                                             feature_spec)\n    return ServingInputReceiver(features, receiver_tensors)\n\n  return serving_input_receiver_fn\n\n\ndef _placeholder_from_tensor(t, default_batch_size=None):\n  \"\"\"Creates a placeholder that matches the dtype and shape of passed tensor.\n\n  Args:\n    t: Tensor or EagerTensor\n    default_batch_size: the number of query examples expected per batch. Leave\n      unset for variable batch size (recommended).\n\n  Returns:\n    Placeholder that matches the passed tensor.\n  \"\"\"\n  batch_shape = tf.TensorShape([default_batch_size])\n  shape = batch_shape.concatenate(t.get_shape()[1:])\n\n  # Reuse the feature tensor's op name (t.op.name) for the placeholder,\n  # excluding the index from the tensor's name (t.name):\n  # t.name = \"%s:%d\" % (t.op.name, t._value_index)\n  try:\n    name = t.op.name\n  except AttributeError:\n    # In Eager mode, tensors don't have ops or names, and while they do have\n    # IDs, those are not maintained across runs. 
The name here is used\n    # primarily for debugging, and is not critical to the placeholder.\n    # So, in order to make this Eager-compatible, continue with an empty\n    # name if none is available.\n    name = None\n\n  return tf.compat.v1.placeholder(dtype=t.dtype, shape=shape, name=name)\n\n\ndef _placeholders_from_receiver_tensors_dict(input_vals,\n                                             default_batch_size=None):\n  return {\n      name: _placeholder_from_tensor(t, default_batch_size)\n      for name, t in input_vals.items()\n  }\n\n\n@estimator_export('estimator.export.build_raw_serving_input_receiver_fn')\ndef build_raw_serving_input_receiver_fn(features, default_batch_size=None):\n  \"\"\"Build a serving_input_receiver_fn expecting feature Tensors.\n\n  Creates a serving_input_receiver_fn that expects all features to be fed\n  directly.\n\n  Args:\n    features: a dict of string to `Tensor`.\n    default_batch_size: the number of query examples expected per batch. Leave\n      unset for variable batch size (recommended).\n\n  Returns:\n    A serving_input_receiver_fn.\n  \"\"\"\n\n  def serving_input_receiver_fn():\n    \"\"\"A serving_input_receiver_fn that expects features to be fed directly.\"\"\"\n    receiver_tensors = _placeholders_from_receiver_tensors_dict(\n        features, default_batch_size)\n    return ServingInputReceiver(receiver_tensors, receiver_tensors)\n\n  return serving_input_receiver_fn\n\n\n@estimator_export(\n    'estimator.experimental.build_raw_supervised_input_receiver_fn')\ndef build_raw_supervised_input_receiver_fn(features,\n                                           labels,\n                                           default_batch_size=None):\n  \"\"\"Build a supervised_input_receiver_fn for raw features and labels.\n\n  This function wraps tensor placeholders in a supervised_receiver_fn\n  with the expectation that the features and labels appear precisely as\n  the model_fn expects them. 
Features and labels can therefore be dicts of\n  tensors, or raw tensors.\n\n  Args:\n    features: a dict of string to `Tensor` or `Tensor`.\n    labels: a dict of string to `Tensor` or `Tensor`.\n    default_batch_size: the number of query examples expected per batch. Leave\n      unset for variable batch size (recommended).\n\n  Returns:\n    A supervised_input_receiver_fn.\n\n  Raises:\n    ValueError: if features and labels have overlapping keys.\n  \"\"\"\n  # Check for overlapping keys before beginning.\n  try:\n    feat_keys = features.keys()\n  except AttributeError:\n    feat_keys = [SINGLE_RECEIVER_DEFAULT_NAME]\n  try:\n    label_keys = labels.keys()\n  except AttributeError:\n    label_keys = [SINGLE_LABEL_DEFAULT_NAME]\n\n  overlap_keys = set(feat_keys) & set(label_keys)\n  if overlap_keys:\n    raise ValueError('Features and labels must have distinct keys. '\n                     'Found overlapping keys: {}'.format(overlap_keys))\n\n  def supervised_input_receiver_fn():\n    \"\"\"A receiver_fn that expects pass-through features and labels.\"\"\"\n    if not isinstance(features, dict):\n      features_cp = _placeholder_from_tensor(features, default_batch_size)\n      receiver_features = {SINGLE_RECEIVER_DEFAULT_NAME: features_cp}\n    else:\n      receiver_features = _placeholders_from_receiver_tensors_dict(\n          features, default_batch_size)\n      features_cp = receiver_features\n\n    if not isinstance(labels, dict):\n      labels_cp = _placeholder_from_tensor(labels, default_batch_size)\n      receiver_labels = {SINGLE_LABEL_DEFAULT_NAME: labels_cp}\n    else:\n      receiver_labels = _placeholders_from_receiver_tensors_dict(\n          labels, default_batch_size)\n      labels_cp = receiver_labels\n\n    receiver_tensors = dict(receiver_features)\n    receiver_tensors.update(receiver_labels)\n    return SupervisedInputReceiver(features_cp, labels_cp, receiver_tensors)\n\n  return supervised_input_receiver_fn\n\n\ndef 
build_supervised_input_receiver_fn_from_input_fn(input_fn, **input_fn_args):\n  \"\"\"Get a function that returns a SupervisedInputReceiver matching an input_fn.\n\n  Note that this function calls the input_fn in a local graph in order to\n  extract features and labels. Placeholders are then created from those\n  features and labels in the default graph.\n\n  Args:\n    input_fn: An Estimator input_fn, which is a function that returns one of:\n      * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a tuple\n        (features, labels) with same constraints as below.\n      * A tuple (features, labels): Where `features` is a `Tensor` or a\n        dictionary of string feature name to `Tensor` and `labels` is a `Tensor`\n        or a dictionary of string label name to `Tensor`. Both `features` and\n        `labels` are consumed by `model_fn`. They should satisfy the expectation\n        of `model_fn` from inputs.\n    **input_fn_args: set of kwargs to be passed to the input_fn. Note that these\n      will not be checked or validated here, and any errors raised by the\n      input_fn will be thrown to the top.\n\n  Returns:\n    A function taking no arguments that, when called, returns a\n    SupervisedInputReceiver. 
This function can be passed in as part of the\n    input_receiver_map when exporting SavedModels from Estimator with multiple\n    modes.\n  \"\"\"\n  # Wrap the input_fn call in a graph to prevent sullying the default namespace\n  with tf.Graph().as_default():\n    result = input_fn(**input_fn_args)\n    features, labels, _ = util.parse_input_fn_result(result)\n  # Placeholders are created back in the default graph.\n  return build_raw_supervised_input_receiver_fn(features, labels)\n\n\n### Below utilities are specific to SavedModel exports.\n# TODO(kathywu): Rename all references to use the original definition in\n# model_utils, or estimator/export/export_lib.py if other estimator export\n# functions are used.\nbuild_all_signature_defs = export_utils.build_all_signature_defs\nget_temp_export_dir = export_utils.get_temp_export_dir\nget_timestamped_export_dir = export_utils.get_timestamped_export_dir\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/export/export_lib.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"All public utility methods for exporting Estimator to SavedModel.\n\nThis file includes functions and constants from core (model_utils) and export.py\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# pylint: disable=unused-import,line-too-long, wildcard-import\nfrom tensorflow.python.saved_model.model_utils import build_all_signature_defs\nfrom tensorflow.python.saved_model.model_utils import export_outputs_for_mode\nfrom tensorflow.python.saved_model.model_utils import EXPORT_TAG_MAP\nfrom tensorflow.python.saved_model.model_utils import get_export_outputs\nfrom tensorflow.python.saved_model.model_utils import get_temp_export_dir\nfrom tensorflow.python.saved_model.model_utils import get_timestamped_export_dir\nfrom tensorflow.python.saved_model.model_utils import SIGNATURE_KEY_MAP\nfrom tensorflow.python.saved_model.model_utils.export_output import _SupervisedOutput\nfrom tensorflow.python.saved_model.model_utils.export_output import ClassificationOutput\nfrom tensorflow.python.saved_model.model_utils.export_output import EvalOutput\nfrom tensorflow.python.saved_model.model_utils.export_output import ExportOutput\nfrom 
tensorflow.python.saved_model.model_utils.export_output import PredictOutput\nfrom tensorflow.python.saved_model.model_utils.export_output import RegressionOutput\nfrom tensorflow.python.saved_model.model_utils.export_output import TrainOutput\nfrom tensorflow_estimator.python.estimator.export.export import build_parsing_serving_input_receiver_fn\nfrom tensorflow_estimator.python.estimator.export.export import build_raw_serving_input_receiver_fn\nfrom tensorflow_estimator.python.estimator.export.export import build_raw_supervised_input_receiver_fn\nfrom tensorflow_estimator.python.estimator.export.export import build_supervised_input_receiver_fn_from_input_fn\nfrom tensorflow_estimator.python.estimator.export.export import ServingInputReceiver\nfrom tensorflow_estimator.python.estimator.export.export import SupervisedInputReceiver\nfrom tensorflow_estimator.python.estimator.export.export import TensorServingInputReceiver\nfrom tensorflow_estimator.python.estimator.export.export import UnsupervisedInputReceiver\nfrom tensorflow_estimator.python.estimator.export.export import wrap_and_check_input_tensors\n# pylint: enable=unused-import,line-too-long, wildcard-import\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/export/export_output.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Classes for different types of export output.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# pylint: disable=unused-import\nfrom tensorflow.python.saved_model.model_utils.export_output import _SupervisedOutput\nfrom tensorflow.python.saved_model.model_utils.export_output import ClassificationOutput\nfrom tensorflow.python.saved_model.model_utils.export_output import EvalOutput\nfrom tensorflow.python.saved_model.model_utils.export_output import ExportOutput\nfrom tensorflow.python.saved_model.model_utils.export_output import PredictOutput\nfrom tensorflow.python.saved_model.model_utils.export_output import RegressionOutput\nfrom tensorflow.python.saved_model.model_utils.export_output import TrainOutput\n# pylint: enable=unused-import\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\n\nestimator_export('estimator.export.ExportOutput')(ExportOutput)\nestimator_export('estimator.export.ClassificationOutput')(ClassificationOutput)\nestimator_export('estimator.export.RegressionOutput')(RegressionOutput)\nestimator_export('estimator.export.PredictOutput')(PredictOutput)\nestimator_export('estimator.export.EvalOutput')(EvalOutput)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/export/export_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for export.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\n\nfrom google.protobuf import text_format\n\nfrom tensorflow.core.example import example_pb2\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.framework import tensor_shape\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.export import export\n\n\nclass LabeledTensorMock(object):\n  \"\"\"Mock class emulating LabeledTensor.\"\"\"\n\n  def __init__(self):\n    self.tensor = tf.constant([1])\n\n\ndef _convert_labeled_tensor_mock_to_tensor(value, *args, **kwargs):\n  return ops.internal_convert_to_tensor(value.tensor, *args, **kwargs)\n\n\ntf.register_tensor_conversion_function(LabeledTensorMock,\n                                       _convert_labeled_tensor_mock_to_tensor)\n\n\nclass ServingInputReceiverTest(tf.test.TestCase):\n\n  def test_serving_input_receiver_constructor(self):\n    \"\"\"Tests that no errors are raised when input is expected.\"\"\"\n    features = {\n        \"feature0\": tf.constant([0]),\n        u\"feature1\": tf.constant([1]),\n        \"feature2\": tf.sparse.SparseTensor(\n            indices=[[0, 
0]], values=[1], dense_shape=[1, 1]),\n        # ints are allowed only in the `features` dict\n        42: tf.constant([3]),\n    }\n    receiver_tensors = {\n        \"example0\": tf.constant([\"test0\"], name=\"example0\"),\n        u\"example1\": tf.constant([\"test1\"], name=\"example1\"),\n    }\n    export.ServingInputReceiver(features, receiver_tensors)\n\n  def test_serving_input_receiver_features_invalid(self):\n    receiver_tensors = {\n        \"example0\": tf.constant([\"test0\"], name=\"example0\"),\n        u\"example1\": tf.constant([\"test1\"], name=\"example1\"),\n    }\n\n    with self.assertRaisesRegexp(ValueError, \"features must be defined\"):\n      export.ServingInputReceiver(\n          features=None, receiver_tensors=receiver_tensors)\n\n    with self.assertRaisesRegexp(ValueError,\n                                 \"feature keys must be strings or ints\"):\n      export.ServingInputReceiver(\n          features={42.2: tf.constant([1])}, receiver_tensors=receiver_tensors)\n\n    with self.assertRaisesRegexp(\n        ValueError, \"feature feature1 must be a Tensor, SparseTensor, or \"\n        \"RaggedTensor.\"):\n      export.ServingInputReceiver(\n          features={\"feature1\": [1]}, receiver_tensors=receiver_tensors)\n\n  def test_serving_input_receiver_receiver_tensors_invalid(self):\n    features = {\n        \"feature0\": tf.constant([0]),\n        u\"feature1\": tf.constant([1]),\n        \"feature2\": tf.sparse.SparseTensor(\n            indices=[[0, 0]], values=[1], dense_shape=[1, 1]),\n    }\n\n    with self.assertRaisesRegexp(ValueError,\n                                 \"receiver_tensors must be defined\"):\n      export.ServingInputReceiver(features=features, receiver_tensors=None)\n\n    with self.assertRaisesRegexp(ValueError,\n                                 \"receiver_tensor keys must be strings\"):\n      export.ServingInputReceiver(\n          features=features,\n          receiver_tensors={1: 
tf.constant([\"test\"], name=\"example0\")})\n\n    with self.assertRaisesRegexp(ValueError,\n                                 \"receiver_tensor example1 must be a Tensor\"):\n      export.ServingInputReceiver(\n          features=features, receiver_tensors={\"example1\": [1]})\n\n  def test_single_feature_single_receiver(self):\n    feature = tf.constant(5)\n    receiver_tensor = tf.constant([\"test\"])\n    input_receiver = export.ServingInputReceiver(feature, receiver_tensor)\n    # single feature is automatically named\n    feature_key, = input_receiver.features.keys()\n    self.assertEqual(\"feature\", feature_key)\n    # single receiver is automatically named\n    receiver_key, = input_receiver.receiver_tensors.keys()\n    self.assertEqual(\"input\", receiver_key)\n\n  def test_multi_feature_single_receiver(self):\n    features = {\"foo\": tf.constant(5), \"bar\": tf.constant(6)}\n    receiver_tensor = tf.constant([\"test\"])\n    _ = export.ServingInputReceiver(features, receiver_tensor)\n\n  def test_multi_feature_multi_receiver(self):\n    features = {\"foo\": tf.constant(5), \"bar\": tf.constant(6)}\n    receiver_tensors = {\"baz\": tf.constant(5), \"qux\": tf.constant(6)}\n    _ = export.ServingInputReceiver(features, receiver_tensors)\n\n  def test_feature_wrong_type(self):\n    feature = \"not a tensor\"\n    receiver_tensor = tf.constant([\"test\"])\n    with self.assertRaises(ValueError):\n      _ = export.ServingInputReceiver(feature, receiver_tensor)\n\n  def test_feature_labeled_tensor(self):\n    feature = LabeledTensorMock()\n    receiver_tensor = tf.constant([\"test\"])\n    _ = export.ServingInputReceiver(feature, receiver_tensor)\n\n  def test_receiver_wrong_type(self):\n    feature = tf.constant(5)\n    receiver_tensor = \"not a tensor\"\n    with self.assertRaises(ValueError):\n      _ = export.ServingInputReceiver(feature, receiver_tensor)\n\n\nclass UnsupervisedInputReceiverTest(tf.test.TestCase):\n\n  # Since this is basically a wrapper 
around ServingInputReceiver, we only\n  # have a simple sanity check to ensure that it works.\n\n  def test_unsupervised_input_receiver_constructor(self):\n    \"\"\"Tests that no errors are raised when input is expected.\"\"\"\n    features = {\n        \"feature0\":\n            tf.constant([0]),\n        u\"feature1\":\n            tf.constant([1]),\n        \"feature2\":\n            tf.sparse.SparseTensor(\n                indices=[[0, 0]], values=[1], dense_shape=[1, 1]),\n        42:  # ints are allowed only in the `features` dict\n            tf.constant([3]),\n    }\n    receiver_tensors = {\n        \"example0\": tf.constant([\"test0\"], name=\"example0\"),\n        u\"example1\": tf.constant([\"test1\"], name=\"example1\"),\n    }\n    export.UnsupervisedInputReceiver(features, receiver_tensors)\n\n\nclass SupervisedInputReceiverTest(tf.test.TestCase):\n\n  def test_input_receiver_constructor(self):\n    \"\"\"Tests that no errors are raised when input is expected.\"\"\"\n    features = {\n        \"feature0\":\n            tf.constant([0]),\n        u\"feature1\":\n            tf.constant([1]),\n        \"feature2\":\n            tf.sparse.SparseTensor(\n                indices=[[0, 0]], values=[1], dense_shape=[1, 1]),\n        42:  # ints are allowed in the `features` dict\n            tf.constant([3]),\n    }\n    labels = {\n        \"classes\": tf.constant([0] * 100),\n        43:  # ints are allowed in the `labels` dict\n            tf.constant([3]),\n    }\n\n    receiver_tensors = {\n        \"example0\": tf.constant([\"test0\"], name=\"example0\"),\n        u\"example1\": tf.constant([\"test1\"], name=\"example1\"),\n    }\n    export.SupervisedInputReceiver(features, labels, receiver_tensors)\n\n  def test_input_receiver_raw_values(self):\n    \"\"\"Tests that no errors are raised when input is expected.\"\"\"\n    features = {\n        \"feature0\":\n            tf.constant([0]),\n        u\"feature1\":\n            tf.constant([1]),\n        
\"feature2\":\n            tf.sparse.SparseTensor(\n                indices=[[0, 0]], values=[1], dense_shape=[1, 1]),\n        42:  # ints are allowed in the `features` dict\n            tf.constant([3]),\n    }\n\n    labels = {\n        \"classes\": tf.constant([0] * 100),\n        43:  # ints are allowed in the `labels` dict\n            tf.constant([3]),\n    }\n\n    receiver_tensors = {\n        \"example0\": tf.constant([\"test0\"], name=\"example0\"),\n        u\"example1\": tf.constant([\"test1\"], name=\"example1\"),\n    }\n    rec = export.SupervisedInputReceiver(features[\"feature2\"], labels,\n                                         receiver_tensors)\n    self.assertIsInstance(rec.features, tf.sparse.SparseTensor)\n\n    rec = export.SupervisedInputReceiver(features, labels[\"classes\"],\n                                         receiver_tensors)\n    self.assertIsInstance(rec.labels, tf.Tensor)\n\n  def test_input_receiver_features_invalid(self):\n    features = tf.constant([0] * 100)\n    labels = tf.constant([0])\n    receiver_tensors = {\n        \"example0\": tf.constant([\"test0\"], name=\"example0\"),\n        u\"example1\": tf.constant([\"test1\"], name=\"example1\"),\n    }\n\n    with self.assertRaisesRegexp(ValueError, \"features must be defined\"):\n      export.SupervisedInputReceiver(\n          features=None, labels=labels, receiver_tensors=receiver_tensors)\n\n    with self.assertRaisesRegexp(ValueError,\n                                 \"feature keys must be strings or ints\"):\n      export.SupervisedInputReceiver(\n          features={1.11: tf.constant([1])},\n          labels=labels,\n          receiver_tensors=receiver_tensors)\n\n    with self.assertRaisesRegexp(ValueError,\n                                 \"label keys must be strings or ints\"):\n      export.SupervisedInputReceiver(\n          features=features,\n          labels={1.11: tf.constant([1])},\n          receiver_tensors=receiver_tensors)\n\n    with 
self.assertRaisesRegexp(\n        ValueError, \"feature feature1 must be a Tensor, SparseTensor, or \"\n        \"RaggedTensor.\"):\n      export.SupervisedInputReceiver(\n          features={\"feature1\": [1]},\n          labels=labels,\n          receiver_tensors=receiver_tensors)\n\n    with self.assertRaisesRegexp(ValueError,\n                                 \"feature must be a Tensor, SparseTensor, \"\n                                 \"or RaggedTensor.\"):\n      export.SupervisedInputReceiver(\n          features=[1], labels=labels, receiver_tensors=receiver_tensors)\n\n    with self.assertRaisesRegexp(ValueError,\n                                 \"label must be a Tensor, SparseTensor, \"\n                                 \"or RaggedTensor.\"):\n      export.SupervisedInputReceiver(\n          features=features, labels=100, receiver_tensors=receiver_tensors)\n\n  def test_input_receiver_receiver_tensors_invalid(self):\n    features = {\n        \"feature0\":\n            tf.constant([0]),\n        u\"feature1\":\n            tf.constant([1]),\n        \"feature2\":\n            tf.sparse.SparseTensor(\n                indices=[[0, 0]], values=[1], dense_shape=[1, 1]),\n    }\n    labels = tf.constant([0])\n\n    with self.assertRaisesRegexp(ValueError,\n                                 \"receiver_tensors must be defined\"):\n      export.SupervisedInputReceiver(\n          features=features, labels=labels, receiver_tensors=None)\n\n    with self.assertRaisesRegexp(ValueError,\n                                 \"receiver_tensor keys must be strings\"):\n      export.SupervisedInputReceiver(\n          features=features,\n          labels=labels,\n          receiver_tensors={1: tf.constant([\"test\"], name=\"example0\")})\n\n    with self.assertRaisesRegexp(ValueError,\n                                 \"receiver_tensor example1 must be a Tensor\"):\n      export.SupervisedInputReceiver(\n          features=features, labels=labels, 
receiver_tensors={\"example1\": [1]})\n\n  def test_single_feature_single_receiver(self):\n    feature = tf.constant(5)\n    label = tf.constant(5)\n    receiver_tensor = tf.constant([\"test\"])\n    input_receiver = export.SupervisedInputReceiver(feature, label,\n                                                    receiver_tensor)\n\n    # single receiver is automatically named\n    receiver_key, = input_receiver.receiver_tensors.keys()\n    self.assertEqual(\"input\", receiver_key)\n\n  def test_multi_feature_single_receiver(self):\n    features = {\"foo\": tf.constant(5), \"bar\": tf.constant(6)}\n    labels = {\"value\": tf.constant(5)}\n    receiver_tensor = tf.constant([\"test\"])\n    _ = export.SupervisedInputReceiver(features, labels, receiver_tensor)\n\n  def test_multi_feature_multi_receiver(self):\n    features = {\"foo\": tf.constant(5), \"bar\": tf.constant(6)}\n    labels = {\"value\": tf.constant(5)}\n    receiver_tensors = {\"baz\": tf.constant(5), \"qux\": tf.constant(6)}\n    _ = export.SupervisedInputReceiver(features, labels, receiver_tensors)\n\n  def test_feature_labeled_tensor(self):\n    feature = LabeledTensorMock()\n    label = tf.constant(5)\n    receiver_tensor = tf.constant([\"test\"])\n    _ = export.SupervisedInputReceiver(feature, label, receiver_tensor)\n\n\nclass ExportTest(tf.test.TestCase):\n\n  # Calling serving_input_receiver_fn requires graph mode.\n  @test_util.deprecated_graph_mode_only\n  def test_build_parsing_serving_input_receiver_fn(self):\n    feature_spec = {\n        \"int_feature\": tf.io.VarLenFeature(tf.dtypes.int64),\n        \"float_feature\": tf.io.VarLenFeature(tf.dtypes.float32)\n    }\n    serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(\n        feature_spec)\n    with tf.Graph().as_default():\n      serving_input_receiver = serving_input_receiver_fn()\n      self.assertEqual(\n          set([\"int_feature\", \"float_feature\"]),\n          
set(serving_input_receiver.features.keys()))\n      self.assertEqual(\n          set([\"examples\"]),\n          set(serving_input_receiver.receiver_tensors.keys()))\n\n      example = example_pb2.Example()\n      text_format.Parse(\n          \"features: { \"\n          \"  feature: { \"\n          \"    key: 'int_feature' \"\n          \"    value: { \"\n          \"      int64_list: { \"\n          \"        value: [ 21, 2, 5 ] \"\n          \"      } \"\n          \"    } \"\n          \"  } \"\n          \"  feature: { \"\n          \"    key: 'float_feature' \"\n          \"    value: { \"\n          \"      float_list: { \"\n          \"        value: [ 525.25 ] \"\n          \"      } \"\n          \"    } \"\n          \"  } \"\n          \"} \", example)\n\n      with self.cached_session() as sess:\n        sparse_result = sess.run(\n            serving_input_receiver.features,\n            feed_dict={\n                serving_input_receiver.receiver_tensors[\"examples\"].name: [\n                    example.SerializeToString()\n                ]\n            })\n        self.assertAllEqual([[0, 0], [0, 1], [0, 2]],\n                            sparse_result[\"int_feature\"].indices)\n        self.assertAllEqual([21, 2, 5], sparse_result[\"int_feature\"].values)\n        self.assertAllEqual([[0, 0]], sparse_result[\"float_feature\"].indices)\n        self.assertAllEqual([525.25], sparse_result[\"float_feature\"].values)\n\n  # Calling serving_input_receiver_fn requires graph mode.\n  @test_util.deprecated_graph_mode_only\n  def test_build_raw_serving_input_receiver_fn_name(self):\n    \"\"\"Test case for issue #12755.\"\"\"\n    f = {\n        \"feature\":\n            tf.compat.v1.placeholder(\n                name=\"feature\", shape=[32], dtype=tf.dtypes.float32)\n    }\n    serving_input_receiver_fn = export.build_raw_serving_input_receiver_fn(f)\n    v = serving_input_receiver_fn()\n    self.assertIsInstance(v, export.ServingInputReceiver)\n\n  # 
Calling serving_input_receiver_fn requires graph mode.\n  @test_util.deprecated_graph_mode_only\n  def test_build_raw_serving_input_receiver_fn_without_shape(self):\n    \"\"\"Test case for issue #21178.\"\"\"\n    f = {\n        \"feature_1\": tf.compat.v1.placeholder(tf.dtypes.float32),\n        \"feature_2\": tf.compat.v1.placeholder(tf.dtypes.int32)\n    }\n    serving_input_receiver_fn = export.build_raw_serving_input_receiver_fn(f)\n    v = serving_input_receiver_fn()\n    self.assertIsInstance(v, export.ServingInputReceiver)\n    self.assertEqual(tensor_shape.unknown_shape(),\n                     v.receiver_tensors[\"feature_1\"].shape)\n    self.assertEqual(tensor_shape.unknown_shape(),\n                     v.receiver_tensors[\"feature_2\"].shape)\n\n  def test_build_raw_serving_input_receiver_fn(self):\n    features = {\n        \"feature_1\": tf.constant([\"hello\"]),\n        \"feature_2\": tf.constant([42])\n    }\n    serving_input_receiver_fn = export.build_raw_serving_input_receiver_fn(\n        features)\n    with tf.Graph().as_default():\n      serving_input_receiver = serving_input_receiver_fn()\n      self.assertEqual(\n          set([\"feature_1\", \"feature_2\"]),\n          set(serving_input_receiver.features.keys()))\n      self.assertEqual(\n          set([\"feature_1\", \"feature_2\"]),\n          set(serving_input_receiver.receiver_tensors.keys()))\n      self.assertEqual(\n          tf.dtypes.string,\n          serving_input_receiver.receiver_tensors[\"feature_1\"].dtype)\n      self.assertEqual(\n          tf.dtypes.int32,\n          serving_input_receiver.receiver_tensors[\"feature_2\"].dtype)\n\n  def test_build_raw_supervised_input_receiver_fn(self):\n    features = {\n        \"feature_1\": tf.constant([\"hello\"]),\n        \"feature_2\": tf.constant([42])\n    }\n    labels = {\"foo\": tf.constant([5]), \"bar\": tf.constant([6])}\n    input_receiver_fn = export.build_raw_supervised_input_receiver_fn(\n        features, labels)\n  
  with tf.Graph().as_default():\n      input_receiver = input_receiver_fn()\n      self.assertEqual(\n          set([\"feature_1\", \"feature_2\"]), set(input_receiver.features.keys()))\n      self.assertEqual(set([\"foo\", \"bar\"]), set(input_receiver.labels.keys()))\n      self.assertEqual(\n          set([\"feature_1\", \"feature_2\", \"foo\", \"bar\"]),\n          set(input_receiver.receiver_tensors.keys()))\n      self.assertEqual(tf.dtypes.string,\n                       input_receiver.receiver_tensors[\"feature_1\"].dtype)\n      self.assertEqual(tf.dtypes.int32,\n                       input_receiver.receiver_tensors[\"feature_2\"].dtype)\n\n  def test_build_raw_supervised_input_receiver_fn_raw_tensors(self):\n    features = {\n        \"feature_1\": tf.constant([\"hello\"]),\n        \"feature_2\": tf.constant([42])\n    }\n    labels = {\"foo\": tf.constant([5]), \"bar\": tf.constant([6])}\n    input_receiver_fn1 = export.build_raw_supervised_input_receiver_fn(\n        features[\"feature_1\"], labels)\n    input_receiver_fn2 = export.build_raw_supervised_input_receiver_fn(\n        features[\"feature_1\"], labels[\"foo\"])\n    with tf.Graph().as_default():\n      input_receiver = input_receiver_fn1()\n      self.assertIsInstance(input_receiver.features, tf.Tensor)\n      self.assertEqual(set([\"foo\", \"bar\"]), set(input_receiver.labels.keys()))\n      self.assertEqual(\n          set([\"input\", \"foo\", \"bar\"]),\n          set(input_receiver.receiver_tensors.keys()))\n\n      input_receiver = input_receiver_fn2()\n      self.assertIsInstance(input_receiver.features, tf.Tensor)\n      self.assertIsInstance(input_receiver.labels, tf.Tensor)\n      self.assertEqual(\n          set([\"input\", \"label\"]), set(input_receiver.receiver_tensors.keys()))\n\n  def test_build_raw_supervised_input_receiver_fn_batch_size(self):\n    features = {\n        \"feature_1\": tf.constant([\"hello\"]),\n        \"feature_2\": tf.constant([42])\n    }\n    labels = 
{\"foo\": tf.constant([5]), \"bar\": tf.constant([6])}\n    input_receiver_fn = export.build_raw_supervised_input_receiver_fn(\n        features, labels, default_batch_size=10)\n    with tf.Graph().as_default():\n      input_receiver = input_receiver_fn()\n      self.assertEqual([10], input_receiver.receiver_tensors[\"feature_1\"].shape)\n      self.assertEqual([10], input_receiver.features[\"feature_1\"].shape)\n\n  def test_build_raw_supervised_input_receiver_fn_overlapping_keys(self):\n    features = {\n        \"feature_1\": tf.constant([\"hello\"]),\n        \"feature_2\": tf.constant([42])\n    }\n    labels = {\"feature_1\": tf.constant([5]), \"bar\": tf.constant([6])}\n    with self.assertRaises(ValueError):\n      export.build_raw_supervised_input_receiver_fn(features, labels)\n\n  def test_build_supervised_input_receiver_fn_from_input_fn(self):\n\n    def dummy_input_fn():\n      return ({\n          \"x\": tf.constant([[1], [1]]),\n          \"y\": tf.constant([\"hello\", \"goodbye\"])\n      }, tf.constant([[1], [1]]))\n\n    input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn(\n        dummy_input_fn)\n\n    with tf.Graph().as_default():\n      input_receiver = input_receiver_fn()\n      self.assertEqual(set([\"x\", \"y\"]), set(input_receiver.features.keys()))\n      self.assertIsInstance(input_receiver.labels, tf.Tensor)\n      self.assertEqual(\n          set([\"x\", \"y\", \"label\"]), set(input_receiver.receiver_tensors.keys()))\n\n  def test_build_supervised_input_receiver_fn_from_input_fn_args(self):\n\n    def dummy_input_fn(feature_key=\"x\"):\n      return ({\n          feature_key: tf.constant([[1], [1]]),\n          \"y\": tf.constant([\"hello\", \"goodbye\"])\n      }, {\n          \"my_label\": tf.constant([[1], [1]])\n      })\n\n    input_receiver_fn = export.build_supervised_input_receiver_fn_from_input_fn(\n        dummy_input_fn, feature_key=\"z\")\n\n    with tf.Graph().as_default():\n      input_receiver = 
input_receiver_fn()\n      self.assertEqual(set([\"z\", \"y\"]), set(input_receiver.features.keys()))\n      self.assertEqual(set([\"my_label\"]), set(input_receiver.labels.keys()))\n      self.assertEqual(\n          set([\"z\", \"y\", \"my_label\"]),\n          set(input_receiver.receiver_tensors.keys()))\n\n\nclass TensorServingReceiverTest(tf.test.TestCase):\n\n  def test_tensor_serving_input_receiver_constructor(self):\n    features = tf.constant([0])\n    receiver_tensors = {\n        \"example0\": tf.constant([\"test0\"], name=\"example0\"),\n        u\"example1\": tf.constant([\"test1\"], name=\"example1\"),\n    }\n    r = export.TensorServingInputReceiver(features, receiver_tensors)\n    self.assertIsInstance(r.features, tf.Tensor)\n    self.assertIsInstance(r.receiver_tensors, dict)\n\n  def test_tensor_serving_input_receiver_sparse(self):\n    features = tf.sparse.SparseTensor(\n        indices=[[0, 0]], values=[1], dense_shape=[1, 1])\n    receiver_tensors = {\n        \"example0\": tf.constant([\"test0\"], name=\"example0\"),\n        u\"example1\": tf.constant([\"test1\"], name=\"example1\"),\n    }\n    r = export.TensorServingInputReceiver(features, receiver_tensors)\n    self.assertIsInstance(r.features, tf.sparse.SparseTensor)\n    self.assertIsInstance(r.receiver_tensors, dict)\n\n  def test_serving_input_receiver_features_invalid(self):\n    receiver_tensors = {\n        \"example0\": tf.constant([\"test0\"], name=\"example0\"),\n        u\"example1\": tf.constant([\"test1\"], name=\"example1\"),\n    }\n\n    with self.assertRaisesRegexp(ValueError, \"features must be defined\"):\n      export.TensorServingInputReceiver(\n          features=None, receiver_tensors=receiver_tensors)\n\n    with self.assertRaisesRegexp(ValueError, \"feature must be a Tensor\"):\n      export.TensorServingInputReceiver(\n          features={\"1\": tf.constant([1])}, receiver_tensors=receiver_tensors)\n\n  def 
test_serving_input_receiver_receiver_tensors_invalid(self):\n    features = tf.constant([0])\n\n    with self.assertRaisesRegexp(ValueError,\n                                 \"receiver_tensors must be defined\"):\n      export.TensorServingInputReceiver(\n          features=features, receiver_tensors=None)\n\n    with self.assertRaisesRegexp(ValueError,\n                                 \"receiver_tensor keys must be strings\"):\n      export.TensorServingInputReceiver(\n          features=features,\n          receiver_tensors={1: tf.constant([\"test\"], name=\"example0\")})\n\n    with self.assertRaisesRegexp(ValueError,\n                                 \"receiver_tensor example1 must be a Tensor\"):\n      export.TensorServingInputReceiver(\n          features=features, receiver_tensors={\"example1\": [1]})\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/export/function.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Defines class for wrapping an Estimator model function.\"\"\"\n# TODO(kathywu): support remaining outputs from the EstimatorSpec.\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.eager import function\nfrom tensorflow.python.eager import wrap_function\nfrom tensorflow.python.framework import func_graph\nfrom tensorflow.python.saved_model.model_utils import export_utils\nfrom tensorflow.python.util import function_utils\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n\nclass ModelFunction(tf.compat.v2.__internal__.tracking.AutoTrackable):\n  \"\"\"A checkpointable ModelFunction object.\n\n  This object stores a global mapping of variables and functions for each mode.\n  \"\"\"\n\n  def __init__(self, config=None, params=None):\n    self._config = config\n    self._params = params\n    self._functions = {}\n\n    self._variable_holder = wrap_function.VariableHolder(share_variables=True)\n\n    # Add reference to the variable holder's mapping of variables, which is a\n    # trackable object.\n    self._variables_by_name = 
self._variable_holder.variables\n\n  @staticmethod\n  def from_function(model_fn, all_modes=None, config=None, params=None):\n    \"\"\"Creates a new ModelFunction object from a model function.\"\"\"\n    if all_modes is None:\n      all_modes = [ModeKeys.TRAIN, ModeKeys.EVAL, ModeKeys.PREDICT]\n    else:\n      all_modes = list(all_modes)\n\n    obj = ModelFunction(config=config, params=params)\n    for mode in all_modes:\n      obj.add_mode(model_fn, mode)\n    return obj\n\n  @property\n  def variables(self):\n    return self._variables_by_name\n\n  def add_mode(self, fn, mode, input_signature=None):\n    if mode in self._functions:\n      raise ValueError('ModelFunction object has multiple functions with name'\n                       ' {}.'.format(mode))\n\n    spec_fn = EstimatorSpecFunction(\n        fn,\n        mode,\n        config=self._config,\n        params=self._params,\n        variable_holder=self._variable_holder,\n        input_signature=input_signature)\n\n    self._functions[mode] = spec_fn\n\n  def train(self, features, labels):\n    return self.call(ModeKeys.TRAIN, features, labels)\n\n  def evaluate(self, features, labels):\n    return self.call(ModeKeys.EVAL, features, labels)\n\n  def predict(self, features):\n    return self.call(ModeKeys.PREDICT, features)\n\n  def call(self, mode, features, labels=None):\n    if mode not in self._functions:\n      raise ValueError(\n          'Mode {} is not defined the ModelFunction. To add modes,'\n          ' use the `add_mode()` function. 
Available modes: {}'.format(\n              mode, self._functions.keys()))\n    fn = self._functions[mode]\n    if fn.expects_labels:\n      return fn(features, labels)\n    else:\n      return fn(features)\n\n\ndef _wrap_and_verify_model_fn(model_fn,\n                              mode=None,\n                              config=None,\n                              params=None,\n                              input_signature=None):\n  \"\"\"Returns a function that only has only tensor arguments (features, labels).\n\n  Args:\n    model_fn: Model function. Must follow the signature defined in\n      `tf.estimator.Estimator`.\n    mode: Optional string `tf.estimstor.ModeKey`.\n    config: Optional `estimator.RunConfig` object.\n    params: Optional `dict` of hyperparameters.\n    input_signature: Possibly nested TensorSpec of the tensor arguments.\n\n  Returns:\n    tuple of (\n      function that only accepts tensor arguments (features and/or labels),\n      whether the returned function expects a labels argument)\n  \"\"\"\n  model_fn_lib.verify_model_fn_args(model_fn, params)\n  args = function_utils.fn_args(model_fn)\n  kwargs = {}\n  if 'mode' in args:\n    kwargs['mode'] = mode\n  if 'params' in args:\n    kwargs['params'] = params\n  if 'config' in args:\n    kwargs['config'] = config\n\n  if 'labels' in args:\n    if input_signature is None or len(input_signature) == 2:\n\n      def wrapped_model_fn(features, labels=None):\n        return model_fn(features=features, labels=labels, **kwargs)\n    else:\n\n      def wrapped_model_fn(features):\n        return model_fn(features=features, labels=None, **kwargs)\n  else:\n\n    def wrapped_model_fn(features):\n      return model_fn(features=features, **kwargs)\n\n  return wrapped_model_fn, 'labels' in args\n\n\nclass EstimatorSpecFunction(tf.compat.v2.__internal__.function.Function):\n  \"\"\"Wraps graph functions defined for a function returning an EstimatorSpec.\n\n  Instances of this class are revivable when 
attached to a checkpointable\n  object.\n  \"\"\"\n\n  def __init__(self,\n               fn,\n               mode,\n               config=None,\n               params=None,\n               variable_holder=None,\n               **kwargs):\n    \"\"\"Initializes an EstimatorSpecFunction.\n\n    Args:\n      fn: Python model function.\n      mode: String mode to run the function.\n      config: RunConfig that is passed to the `config` arg in the function.\n      params: object that is passed to the `params` argument in the function.\n      variable_holder: Optional `wrap_function.VariableHolder` object.\n      **kwargs: Optional keyword arguments to pass to tf.function (e.g.\n        input_signature).\n    \"\"\"\n    python_function, self.expects_labels = _wrap_and_verify_model_fn(\n        fn,\n        mode=mode,\n        config=config,\n        params=params,\n        input_signature=kwargs.get('input_signature', None))\n    super(EstimatorSpecFunction, self).__init__(python_function, mode, **kwargs)\n    self._variable_holder = variable_holder\n\n  def _defun(self, fn):\n    return _EstimatorSpecFunction(\n        fn,\n        name=self._name,\n        variable_holder=self._variable_holder,\n        input_signature=self.input_signature,\n        autograph=self._autograph,\n        autograph_options=self._experimental_autograph_options)\n\n\nclass _EstimatorSpecFunction(tf.compat.v2.__internal__.function.Function):\n  \"\"\"Wraps graph functions defined for a function returning an EstimatorSpec.\n\n  This object handles creation of the graph functions.\n  \"\"\"\n\n  def __init__(self, python_function, name, variable_holder=None, **kwargs):\n    super(_EstimatorSpecFunction, self).__init__(python_function, name,\n                                                 **kwargs)\n    self._variable_holder = variable_holder\n\n  def _create_graph_function(self, args, kwargs, **other_kwargs):\n    _ = other_kwargs\n    wrapped_graph = 
_EstimatorWrappedGraph(self._variable_holder)\n    return wrapped_graph.wrap_model_fn(\n        self._python_function,\n        self._name,\n        signature=self.input_signature,\n        args=args,\n        kwargs=kwargs)\n\n\nclass _EstimatorWrappedGraph(wrap_function.WrappedGraph):\n  \"\"\"WrappedGraph that handles global step creation and wraps estimator fns.\"\"\"\n\n  def __init__(self, *args, **kwargs):\n    super(_EstimatorWrappedGraph, self).__init__(*args, **kwargs)\n    # Create global step variable, which may be used by the input and model fns.\n    self._global_step_read_fn = self.wrap_function(\n        self._global_step, signature=[])\n\n    self._concrete_model_fn = None\n\n    # Original EstimatorSpec object returned by the model function. Only tensors\n    # and ops are returned by the concrete model function.\n    self._estimator_spec = None\n\n  def _global_step(self):\n    return tf.compat.v1.train.get_or_create_global_step()\n\n  @property\n  def global_step(self):\n    return self._global_step_read_fn()\n\n  @property\n  def model_fn(self):\n    return self._concrete_model_fn\n\n  @property\n  def estimator_spec(self):\n    if self._concrete_model_fn is None:\n      raise ValueError('Please wrap a model function first.')\n    return self._estimator_spec\n\n  def wrap_model_fn(self,\n                    model_fn,\n                    mode,\n                    args=None,\n                    kwargs=None,\n                    signature=None):\n    \"\"\"Wraps a model function, and stores the returned estimator spec.\"\"\"\n    if self._concrete_model_fn is not None:\n      raise ValueError('`wrap_model_fn` should be only called once per graph.')\n\n    def fn(*args, **kwargs):\n      \"\"\"Returns tensor and op outputs from the returned spec.\"\"\"\n      ret = model_fn(*args, **kwargs)\n\n      if isinstance(ret, model_fn_lib.EstimatorSpec):\n        self._estimator_spec = ret\n        return _filter_estimator_spec_outputs(ret)\n      
return ret\n\n    name = 'model_fn_{}'.format(mode)\n    self._concrete_model_fn = self._wrap_function(fn, args, kwargs, signature,\n                                                  name)\n    return self._concrete_model_fn\n\n  def wrap_input_receiver_fn(self, input_receiver_fn):\n    \"\"\"Converts an input receiver function to one or more concrete functions.\n\n    Input receiver functions are python functions with no arguments.\n    Placeholders are created within the function and used to receive inputs to\n    the model.\n\n    The function (or multiple functions) generated depends on the InputReceiver\n    object returned by `input_receiver_fn`.\n\n    Generally, the returned function will have inputs and outputs:\n      input_receiver(**receiver_tensors) --> features\n\n    or (if the InputReceiver returns labels):\n      input_receiver(**receiver_tensors) --> features, labels\n\n    __Alternate Receiver Tensors__\n\n    The InputReceiver may have alternate receiver tensors, in which case\n    additional concrete functions are generated. Example:\n      InputReceiver.receiver_tensors_alternatives = {\n        'alt_input_1': Tensor,\n        'alt_input_2': {\n          'tensor_1': Tensor,\n          'tensor_2': Tensor\n        }\n      }\n\n    This will generate concrete functions:\n      input_receiver_alt_input_1(input) --> features\n      input_receiver_alt_input_2(tensor_1, tensor_2) --> features\n\n    Args:\n      input_receiver_fn: a no-argument function that returns an `InputReceiver`\n        object.\n\n    Returns:\n      A list of tuples of (concrete function, receiver name). 
The name of the\n      default input receiver is `None`.\n    \"\"\"\n    ret = [None]\n\n    def fn():\n      ret[0] = input_receiver = input_receiver_fn()\n      features = input_receiver.features\n      labels = getattr(input_receiver, 'labels', None)\n\n      if labels is None:\n        return features\n      return features, labels\n\n    func_graph.func_graph_from_py_func(\n        None,  # Name is unused.\n        self._variable_holder.call_with_variable_creator_scope(fn),\n        args=None,\n        kwargs=None,\n        signature=[],\n        add_control_dependencies=False,\n        func_graph=self.graph)\n\n    functions = []\n    input_receiver = ret[0]\n\n    wrapped_input_receiver_fn = _prune_receiver_tensors(\n        self._wrapped_function,\n        receiver_tensors=input_receiver.receiver_tensors,\n        outputs=self.graph.structured_outputs,\n        name=_input_receiver_fn_name(None))\n    functions.append((wrapped_input_receiver_fn, None))\n\n    receiver_tensors_alternatives = getattr(input_receiver,\n                                            'receiver_tensors_alternatives',\n                                            None)\n\n    if receiver_tensors_alternatives:\n      for receiver_name, receiver_tensors_alt in (\n          six.iteritems(receiver_tensors_alternatives)):\n        receiver_tensors_alt = _canonicalize_receiver_tensors(\n            receiver_tensors_alt)\n        wrapped_input_receiver_fn = _prune_receiver_tensors(\n            self._wrapped_function,\n            receiver_tensors=receiver_tensors_alt,\n            outputs=self.graph.structured_outputs,\n            name=_input_receiver_fn_name(receiver_name))\n        functions.append((wrapped_input_receiver_fn, receiver_name))\n    return functions\n\n\ndef _filter_estimator_spec_outputs(spec):\n  \"\"\"Filters tensors and ops from an EstimatorSpec and returns a dictionary.\"\"\"\n  # TODO(kathywu): Add loss, export outputs, eval metrics depending on the mode.\n  if 
spec.mode == ModeKeys.TRAIN:\n    return dict(predictions=spec.predictions, train_op=spec.train_op)\n  return dict(predictions=spec.predictions)\n\n\n_RECEIVER_FN_NAME = '_input_receiver'\n\n\ndef _canonicalize_receiver_tensors(receiver_tensors):\n  \"\"\"Converts receiver tensors to the expected format of `as_signature_def`.\"\"\"\n  # TODO(b/129646028): Wrap function doesn't support composite tensors.\n  for tensor in tf.nest.flatten(receiver_tensors):\n    if not isinstance(tensor, tf.Tensor):\n      raise ValueError('All receiver tensors must be tensors (composite '\n                       'tensors are not yet supported).')\n\n  if isinstance(receiver_tensors, dict):\n    return receiver_tensors\n  return {export_utils.SINGLE_RECEIVER_DEFAULT_NAME: receiver_tensors}\n\n\ndef _input_receiver_fn_name(name):\n  if name is None:\n    return _RECEIVER_FN_NAME\n  else:\n    return '{}_{}'.format(_RECEIVER_FN_NAME, name)\n\n\ndef _prune_receiver_tensors(wrapped_function, receiver_tensors, outputs, name):\n  inputs = _canonicalize_receiver_tensors(receiver_tensors)\n  return wrapped_function.prune(\n      inputs,\n      outputs,\n      name=name,\n      input_signature=(None, func_graph.convert_structure_to_signature(inputs)))\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/export/function_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for Estimator function objects.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\n\nimport six as six\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.export import export_lib\nfrom tensorflow_estimator.python.estimator.export import function\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n\ndef _string_fix(obj):\n  return tf.nest.map_structure(\n      lambda x: tf.compat.as_bytes(x)\n      if isinstance(x, six.string_types) else x, obj)\n\n\ndef _model_fn(features, labels, mode):\n  v = tf.Variable(tf.constant(23), name='v')\n  if mode == ModeKeys.PREDICT:\n    return model_fn_lib.EstimatorSpec(\n        ModeKeys.PREDICT, predictions=features + 1)\n  elif mode == ModeKeys.EVAL:\n    return model_fn_lib.EstimatorSpec(\n        ModeKeys.EVAL, loss=tf.constant(5) + v, predictions=features + labels)\n  elif mode == ModeKeys.TRAIN:\n    return model_fn_lib.EstimatorSpec(\n        ModeKeys.TRAIN,\n        predictions=features * labels,\n        loss=tf.constant(5) + v,\n        
train_op=tf.compat.v1.assign_add(tf.compat.v1.train.get_global_step(),\n                                         1))\n\n\ndef _model_fn_train_only(features, labels):\n  v = tf.Variable(tf.constant(23), name='v')\n  return model_fn_lib.EstimatorSpec(\n      ModeKeys.TRAIN,\n      predictions=features * labels,\n      loss=tf.constant(5) + v,\n      train_op=tf.compat.v1.assign_add(tf.compat.v1.train.get_global_step(), 1))\n\n\ndef _model_fn_predict_only(features):\n  return model_fn_lib.EstimatorSpec(ModeKeys.PREDICT, predictions=features + 1)\n\n\n# TODO(kathywu): Re-enable test after def_function changes are built into\n# nightlies.\n@test_util.run_all_in_graph_and_eager_modes\nclass ModelFunctionTest(object):\n\n  def test_from_function(self):\n    mfn = function.ModelFunction.from_function(_model_fn)\n    out = mfn.train(tf.constant(3), tf.constant(5))\n\n    self.evaluate(tf.compat.v1.initializers.variables(mfn.variables.values()))\n\n    self.assertEqual(15, self.evaluate(out['predictions']))\n    out = mfn.evaluate(tf.constant(7), tf.constant(9))\n    self.assertEqual(16, self.evaluate(out['predictions']))\n    out = mfn.predict(tf.constant(10))\n    self.assertEqual(11, self.evaluate(out['predictions']))\n\n  def test_model_fn_train_only(self):\n    mfn = function.ModelFunction()\n    mfn.add_mode(_model_fn_train_only, ModeKeys.TRAIN)\n    out = mfn.train(tf.constant(4), tf.constant(6))\n\n    self.evaluate(tf.compat.v1.initializers.variables(mfn.variables.values()))\n\n    self.assertEqual(24, self.evaluate(out['predictions']))\n\n    with self.assertRaisesRegexp(ValueError, 'not defined'):\n      out = mfn.evaluate(tf.constant(7), tf.constant(9))\n\n  def test_model_fn_predict_only(self):\n    mfn = function.ModelFunction()\n    mfn.add_mode(_model_fn_predict_only, ModeKeys.PREDICT)\n    out = mfn.predict(tf.constant(4))\n\n    self.evaluate(tf.compat.v1.initializers.variables(mfn.variables.values()))\n\n    self.assertEqual(5, 
self.evaluate(out['predictions']))\n\n    with self.assertRaisesRegexp(ValueError, 'not defined'):\n      out = mfn.evaluate(tf.constant(7), tf.constant(9))\n\n  def test_save_and_load(self):\n    mfn = function.ModelFunction.from_function(_model_fn)\n\n    out = mfn.train(tf.constant(3), tf.constant(5))\n    self.evaluate(tf.compat.v1.initializers.variables(mfn.variables.values()))\n    self.evaluate(out['predictions'])\n\n    for _ in range(2):\n      out = mfn.train(tf.constant(3), tf.constant(5))\n      self.evaluate(out['predictions'])\n    self.assertEqual(\n        3, self.evaluate(mfn._variable_holder.variables['global_step']))\n\n    mfn.evaluate(tf.constant(7), tf.constant(9))\n    mfn.predict(tf.constant(10))\n\n    save_dir = os.path.join(self.get_temp_dir(), 'model_function')\n    tf.saved_model.save(mfn, save_dir)\n\n    obj = tf.saved_model.load(save_dir)\n    variables_by_name = obj._variables_by_name\n\n    self.evaluate(\n        tf.compat.v1.initializers.variables(\n            variables_by_name._unconditional_dependency_names.values()))\n    self.assertEqual(3, self.evaluate(variables_by_name.global_step))\n\n    out = obj._functions['train'](tf.constant(3), tf.constant(5))\n    self.assertEqual(15, self.evaluate(out['predictions']))\n    self.assertEqual(4, self.evaluate(variables_by_name.global_step))\n\n    out = obj._functions['eval'](tf.constant(7), tf.constant(9))\n    self.assertEqual(16, self.evaluate(out['predictions']))\n\n    out = obj._functions['infer'](tf.constant(10))\n    self.assertEqual(11, self.evaluate(out['predictions']))\n\n\ndef _model_fn_callable_variable_initializers(features, labels, mode):\n  \"\"\"Model_fn with callable variable initializers (for WrappedGraph tests).\"\"\"\n  _ = features, labels\n  v = tf.Variable(lambda: tf.constant(23), name='v')\n  if mode == ModeKeys.PREDICT:\n    return model_fn_lib.EstimatorSpec(\n        ModeKeys.PREDICT, predictions=features + 1)\n  elif mode == ModeKeys.EVAL:\n    return 
model_fn_lib.EstimatorSpec(\n        ModeKeys.EVAL, loss=tf.constant(5) + v, predictions=features + labels)\n  elif mode == ModeKeys.TRAIN:\n    return model_fn_lib.EstimatorSpec(\n        ModeKeys.TRAIN,\n        predictions=features * labels,\n        loss=tf.constant(5) + v,\n        train_op=tf.compat.v1.assign_add(tf.compat.v1.train.get_global_step(),\n                                         1))\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass EstimatorWrappedGraphTest(tf.test.TestCase):\n\n  def test_wrap_model_fn_train(self):\n    graph = function._EstimatorWrappedGraph()\n    features = tf.constant(3)\n    labels = tf.constant(4)\n    mode = ModeKeys.TRAIN\n    fn = graph.wrap_model_fn(\n        _model_fn_callable_variable_initializers,\n        mode=mode,\n        args=[features, labels, mode],\n        kwargs={})\n    self.evaluate(tf.compat.v1.initializers.variables(graph.variables.values()))\n    self.assertEqual(0, self.evaluate(graph.global_step))\n    self.assertEqual(12, self.evaluate(fn(features, labels)['predictions']))\n    self.assertEqual(1, self.evaluate(graph.global_step))\n\n    self.assertEqual('AssignAddVariableOp', graph.estimator_spec.train_op.type)\n\n  def test_wrap_model_fn_eval(self):\n    graph = function._EstimatorWrappedGraph()\n    features = tf.constant(5)\n    labels = tf.constant(6)\n    mode = ModeKeys.EVAL\n    fn = graph.wrap_model_fn(\n        _model_fn_callable_variable_initializers,\n        mode=mode,\n        args=[features, labels, mode],\n        kwargs={})\n    self.assertDictEqual({'predictions': 11},\n                         self.evaluate(fn(features, labels)))\n\n  def test_wrap_model_fn_predict(self):\n    graph = function._EstimatorWrappedGraph()\n    features = tf.constant(7)\n    mode = ModeKeys.PREDICT\n    fn = graph.wrap_model_fn(\n        _model_fn_callable_variable_initializers,\n        mode=mode,\n        args=[features, None, mode],\n        kwargs={})\n    
self.assertDictEqual({'predictions': 8}, self.evaluate(fn(features)))\n\n  def test_wrap_input_receiver_fn(self):\n\n    def serving_input_fn():\n      receiver_1 = tf.compat.v1.placeholder(tf.dtypes.string)\n      receiver_2 = tf.compat.v1.placeholder(tf.dtypes.string)\n\n      receiver_tensors = {\n          'rec1': receiver_1,\n          u'rec2': receiver_2,\n      }\n\n      concat = tf.strings.join([receiver_1, receiver_2])\n      concat2 = tf.identity(concat)\n      features = {\n          'feature0': tf.strings.join([concat, concat2], ':'),\n          u'feature1': tf.constant([1])\n      }\n\n      alternate_tensors = {\n          'alt_name_1': concat,\n          'alt_name_2': {\n              'tensor1': concat,\n              'tensor2': concat2\n          }\n      }\n      return export_lib.ServingInputReceiver(features, receiver_tensors,\n                                             alternate_tensors)\n\n    graph = function._EstimatorWrappedGraph()\n    fns = graph.wrap_input_receiver_fn(serving_input_fn)\n\n    for fn, name in fns:\n      if name is None:\n        out = fn(tf.constant('1'), tf.constant('2'))\n        self.assertDictEqual(\n            _string_fix({\n                'feature0': '12:12',\n                'feature1': [1]\n            }), _string_fix(self.evaluate(out)))\n      elif name == 'alt_name_1':\n        out = fn(tf.constant('3'))\n        self.assertDictEqual(\n            _string_fix({\n                'feature0': '3:3',\n                'feature1': [1]\n            }), _string_fix(self.evaluate(out)))\n      elif name == 'alt_name_2':\n        out = fn(tf.constant('4'), tf.constant('5'))\n        self.assertDictEqual(\n            _string_fix({\n                'feature0': '4:5',\n                'feature1': [1]\n            }), _string_fix(self.evaluate(out)))\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/exporter.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"`Exporter` class represents different flavors of model export.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport abc\nimport os\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator import gc\nfrom tensorflow_estimator.python.estimator import util\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\n\n\n@estimator_export('estimator.Exporter')\nclass Exporter(object):\n  \"\"\"A class representing a type of model export.\"\"\"\n\n  @abc.abstractproperty\n  def name(self):\n    \"\"\"Directory name.\n\n    A directory name under the export base directory where exports of\n    this type are written.  
Should not be `None` nor empty.\n    \"\"\"\n    pass\n\n  @abc.abstractmethod\n  def export(self, estimator, export_path, checkpoint_path, eval_result,\n             is_the_final_export):\n    \"\"\"Exports the given `Estimator` to a specific format.\n\n    Args:\n      estimator: the `Estimator` to export.\n      export_path: A string containing a directory where to write the export.\n      checkpoint_path: The checkpoint path to export.\n      eval_result: The output of `Estimator.evaluate` on this checkpoint.\n      is_the_final_export: This boolean is True when this is an export in the\n        end of training.  It is False for the intermediate exports during the\n        training. When passing `Exporter` to `tf.estimator.train_and_evaluate`\n        `is_the_final_export` is always False if `TrainSpec.max_steps` is\n        `None`.\n\n    Returns:\n      The string path to the exported directory or `None` if export is skipped.\n    \"\"\"\n    pass\n\n\nclass _SavedModelExporter(Exporter):\n  \"\"\"This class exports the serving graph and checkpoints.\n\n     This class provides a basic exporting functionality and serves as a\n     foundation for specialized `Exporter`s.\n  \"\"\"\n\n  def __init__(self,\n               name,\n               serving_input_receiver_fn,\n               assets_extra=None,\n               as_text=False):\n    \"\"\"Create an `Exporter` to use with `tf.estimator.EvalSpec`.\n\n    Args:\n      name: unique name of this `Exporter` that is going to be used in the\n        export path.\n      serving_input_receiver_fn: a function that takes no arguments and returns\n        a `ServingInputReceiver`.\n      assets_extra: An optional dict specifying how to populate the assets.extra\n        directory within the exported SavedModel.  Each key should give the\n        destination path (including the filename) relative to the assets.extra\n        directory.  
The corresponding value gives the full path of the source\n        file to be copied.  For example, the simple case of copying a single\n        file without renaming it is specified as\n        `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.\n      as_text: whether to write the SavedModel proto in text format. Defaults to\n        `False`.\n\n    Raises:\n      ValueError: if any arguments is invalid.\n    \"\"\"\n    self._name = name\n    self._serving_input_receiver_fn = serving_input_receiver_fn\n    self._assets_extra = assets_extra\n    self._as_text = as_text\n\n  @property\n  def name(self):\n    return self._name\n\n  def export(self, estimator, export_path, checkpoint_path, eval_result,\n             is_the_final_export):\n    del is_the_final_export\n\n    export_result = estimator.export_saved_model(\n        export_path,\n        self._serving_input_receiver_fn,\n        assets_extra=self._assets_extra,\n        as_text=self._as_text,\n        checkpoint_path=checkpoint_path)\n\n    return export_result\n\n\ndef _loss_smaller(best_eval_result, current_eval_result):\n  \"\"\"Compares two evaluation results and returns true if the 2nd one is smaller.\n\n  Both evaluation results should have the values for MetricKeys.LOSS, which are\n  used for comparison.\n\n  Args:\n    best_eval_result: best eval metrics.\n    current_eval_result: current eval metrics.\n\n  Returns:\n    True if the loss of current_eval_result is smaller; otherwise, False.\n\n  Raises:\n    ValueError: If input eval result is None or no loss is available.\n  \"\"\"\n  default_key = metric_keys.MetricKeys.LOSS\n  if not best_eval_result or default_key not in best_eval_result:\n    raise ValueError(\n        'best_eval_result cannot be empty or no loss is found in it.')\n\n  if not current_eval_result or default_key not in current_eval_result:\n    raise ValueError(\n        'current_eval_result cannot be empty or no loss is found in it.')\n\n  return 
best_eval_result[default_key] > current_eval_result[default_key]\n\n\ndef _verify_compare_fn_args(compare_fn):\n  \"\"\"Verifies compare_fn arguments.\"\"\"\n  args = set(util.fn_args(compare_fn))\n  if 'best_eval_result' not in args:\n    raise ValueError('compare_fn (%s) must include best_eval_result argument.' %\n                     compare_fn)\n  if 'current_eval_result' not in args:\n    raise ValueError(\n        'compare_fn (%s) must include current_eval_result argument.' %\n        compare_fn)\n  non_valid_args = list(args - set(['best_eval_result', 'current_eval_result']))\n  if non_valid_args:\n    raise ValueError('compare_fn (%s) has following not expected args: %s' %\n                     (compare_fn, non_valid_args))\n\n\n@estimator_export('estimator.BestExporter')\nclass BestExporter(Exporter):\n  \"\"\"This class exports the serving graph and checkpoints of the best models.\n\n  This class performs a model export everytime the new model is better than any\n  existing model.\n  \"\"\"\n\n  def __init__(self,\n               name='best_exporter',\n               serving_input_receiver_fn=None,\n               event_file_pattern='eval/*.tfevents.*',\n               compare_fn=_loss_smaller,\n               assets_extra=None,\n               as_text=False,\n               exports_to_keep=5):\n    \"\"\"Create an `Exporter` to use with `tf.estimator.EvalSpec`.\n\n    Example of creating a BestExporter for training and evaluation:\n\n    ```python\n    def make_train_and_eval_fn():\n      # Set up feature columns.\n      categorical_feature_a = (\n          tf.feature_column.categorical_column_with_hash_bucket(...))\n      categorical_feature_a_emb = embedding_column(\n          categorical_column=categorical_feature_a, ...)\n      ...  
# other feature columns\n\n      estimator = tf.estimator.DNNClassifier(\n          config=tf.estimator.RunConfig(\n              model_dir='/my_model', save_summary_steps=100),\n          feature_columns=[categorical_feature_a_emb, ...],\n          hidden_units=[1024, 512, 256])\n\n      serving_feature_spec = tf.feature_column.make_parse_example_spec(\n          categorical_feature_a_emb)\n      serving_input_receiver_fn = (\n          tf.estimator.export.build_parsing_serving_input_receiver_fn(\n          serving_feature_spec))\n\n      exporter = tf.estimator.BestExporter(\n          name=\"best_exporter\",\n          serving_input_receiver_fn=serving_input_receiver_fn,\n          exports_to_keep=5)\n\n      train_spec = tf.estimator.TrainSpec(...)\n\n      eval_spec = [tf.estimator.EvalSpec(\n        input_fn=eval_input_fn,\n        steps=100,\n        exporters=exporter,\n        start_delay_secs=0,\n        throttle_secs=5)]\n\n      tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)\n\n    ```\n\n    Args:\n      name: unique name of this `Exporter` that is going to be used in the\n        export path.\n      serving_input_receiver_fn: a function that takes no arguments and returns\n        a `ServingInputReceiver`.\n      event_file_pattern: event file name pattern relative to model_dir. If\n        None, however, the exporter would not be preemption-safe. To be\n        preemption-safe, event_file_pattern must be specified.\n      compare_fn: a function that compares two evaluation results and returns\n        true if current evaluation result is better. 
Follows the signature:\n        * Args:\n          * `best_eval_result`: This is the evaluation result of the best model.\n          * `current_eval_result`: This is the evaluation result of current\n            candidate model.\n        * Returns: True if current evaluation result is better; otherwise,\n          False.\n      assets_extra: An optional dict specifying how to populate the assets.extra\n        directory within the exported SavedModel.  Each key should give the\n        destination path (including the filename) relative to the assets.extra\n        directory.  The corresponding value gives the full path of the source\n        file to be copied.  For example, the simple case of copying a single\n        file without renaming it is specified as `{'my_asset_file.txt':\n          '/path/to/my_asset_file.txt'}`.\n      as_text: whether to write the SavedModel proto in text format. Defaults to\n        `False`.\n      exports_to_keep: Number of exports to keep.  Older exports will be\n        garbage-collected.  Defaults to 5.  Set to `None` to disable garbage\n        collection.\n\n    Raises:\n      ValueError: if any argument is invalid.\n    \"\"\"\n    self._compare_fn = compare_fn\n    if self._compare_fn is None:\n      raise ValueError('`compare_fn` must not be None.')\n    _verify_compare_fn_args(self._compare_fn)\n\n    self._saved_model_exporter = _SavedModelExporter(name,\n                                                     serving_input_receiver_fn,\n                                                     assets_extra, as_text)\n\n    self._event_file_pattern = event_file_pattern\n    self._model_dir = None\n    self._best_eval_result = None\n    self._has_exported = False\n\n    self._exports_to_keep = exports_to_keep\n    if exports_to_keep is not None and exports_to_keep <= 0:\n      raise ValueError(\n          '`exports_to_keep`, if provided, must be a positive number. 
Got %s' %\n          exports_to_keep)\n\n  @property\n  def name(self):\n    return self._saved_model_exporter.name\n\n  def export(self, estimator, export_path, checkpoint_path, eval_result,\n             is_the_final_export):\n    export_result = None\n\n    if self._model_dir != estimator.model_dir and self._event_file_pattern:\n      # Loads best metric from event files.\n      tf.compat.v1.logging.info('Loading best metric from event files.')\n\n      self._model_dir = estimator.model_dir\n      full_event_file_pattern = os.path.join(self._model_dir,\n                                             self._event_file_pattern)\n      self._best_eval_result = self._get_best_eval_result(\n          full_event_file_pattern)\n\n    if (self._best_eval_result is None or\n        # check if this is the first export.\n        not self._has_exported or self._compare_fn(\n            best_eval_result=self._best_eval_result,\n            current_eval_result=eval_result)):\n      tf.compat.v1.logging.info('Performing best model export.')\n      self._best_eval_result = eval_result\n      export_result = self._saved_model_exporter.export(estimator, export_path,\n                                                        checkpoint_path,\n                                                        eval_result,\n                                                        is_the_final_export)\n      self._garbage_collect_exports(export_path)\n      self._has_exported = True\n\n    return export_result\n\n  def _garbage_collect_exports(self, export_dir_base):\n    \"\"\"Deletes older exports, retaining only a given number of the most recent.\n\n    Export subdirectories are assumed to be named with monotonically increasing\n    integers; the most recent are taken to be those with the largest values.\n\n    Args:\n      export_dir_base: the base directory under which each export is in a\n        versioned subdirectory.\n    \"\"\"\n    if self._exports_to_keep is None:\n      return\n\n    def 
_export_version_parser(path):\n      # create a simple parser that pulls the export_version from the directory.\n      filename = os.path.basename(path.path)\n      if not (len(filename) == 10 and filename.isdigit()):\n        return None\n      return path._replace(export_version=int(filename))\n\n    # pylint: disable=protected-access\n    keep_filter = gc._largest_export_versions(self._exports_to_keep)\n    delete_filter = gc._negation(keep_filter)\n    for p in delete_filter(\n        gc._get_paths(export_dir_base, parser=_export_version_parser)):\n      try:\n        tf.compat.v1.gfile.DeleteRecursively(p.path)\n      except tf.errors.NotFoundError as e:\n        tf.compat.v1.logging.warn('Can not delete %s recursively: %s', p.path,\n                                  e)\n    # pylint: enable=protected-access\n\n  def _get_best_eval_result(self, event_files):\n    \"\"\"Get the best eval result from event files.\n\n    Args:\n      event_files: Absolute pattern of event files.\n\n    Returns:\n      The best eval result.\n    \"\"\"\n    if not event_files:\n      return None\n\n    best_eval_result = None\n    for event_file in tf.compat.v1.gfile.Glob(os.path.join(event_files)):\n      for event in tf.compat.v1.train.summary_iterator(event_file):\n        if event.HasField('summary'):\n          event_eval_result = {}\n          for value in event.summary.value:\n            if value.HasField('simple_value'):\n              event_eval_result[value.tag] = value.simple_value\n          if event_eval_result:\n            if best_eval_result is None or self._compare_fn(\n                best_eval_result, event_eval_result):\n              best_eval_result = event_eval_result\n    return best_eval_result\n\n\n@estimator_export('estimator.FinalExporter')\nclass FinalExporter(Exporter):\n  \"\"\"This class exports the serving graph and checkpoints at the end.\n\n  This class performs a single export at the end of training.\n  \"\"\"\n\n  def __init__(self,\n          
     name,\n               serving_input_receiver_fn,\n               assets_extra=None,\n               as_text=False):\n    \"\"\"Create an `Exporter` to use with `tf.estimator.EvalSpec`.\n\n    Args:\n      name: unique name of this `Exporter` that is going to be used in the\n        export path.\n      serving_input_receiver_fn: a function that takes no arguments and returns\n        a `ServingInputReceiver`.\n      assets_extra: An optional dict specifying how to populate the assets.extra\n        directory within the exported SavedModel.  Each key should give the\n        destination path (including the filename) relative to the assets.extra\n        directory.  The corresponding value gives the full path of the source\n        file to be copied.  For example, the simple case of copying a single\n        file without renaming it is specified as\n        `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.\n      as_text: whether to write the SavedModel proto in text format. Defaults to\n        `False`.\n\n    Raises:\n      ValueError: if any arguments is invalid.\n    \"\"\"\n    self._saved_model_exporter = _SavedModelExporter(name,\n                                                     serving_input_receiver_fn,\n                                                     assets_extra, as_text)\n\n  @property\n  def name(self):\n    return self._saved_model_exporter.name\n\n  def export(self, estimator, export_path, checkpoint_path, eval_result,\n             is_the_final_export):\n    if not is_the_final_export:\n      return None\n\n    tf.compat.v1.logging.info(\n        'Performing the final export in the end of training.')\n\n    return self._saved_model_exporter.export(estimator, export_path,\n                                             checkpoint_path, eval_result,\n                                             is_the_final_export)\n\n\n@estimator_export('estimator.LatestExporter')\nclass LatestExporter(Exporter):\n  \"\"\"This class regularly exports 
the serving graph and checkpoints.\n\n  In addition to exporting, this class also garbage collects stale exports.\n  \"\"\"\n\n  def __init__(self,\n               name,\n               serving_input_receiver_fn,\n               assets_extra=None,\n               as_text=False,\n               exports_to_keep=5):\n    \"\"\"Create an `Exporter` to use with `tf.estimator.EvalSpec`.\n\n    Args:\n      name: unique name of this `Exporter` that is going to be used in the\n        export path.\n      serving_input_receiver_fn: a function that takes no arguments and returns\n        a `ServingInputReceiver`.\n      assets_extra: An optional dict specifying how to populate the assets.extra\n        directory within the exported SavedModel.  Each key should give the\n        destination path (including the filename) relative to the assets.extra\n        directory.  The corresponding value gives the full path of the source\n        file to be copied.  For example, the simple case of copying a single\n        file without renaming it is specified as\n        `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.\n      as_text: whether to write the SavedModel proto in text format. Defaults to\n        `False`.\n      exports_to_keep: Number of exports to keep.  Older exports will be\n        garbage-collected.  Defaults to 5.  
Set to `None` to disable garbage\n        collection.\n\n    Raises:\n      ValueError: if any arguments is invalid.\n    \"\"\"\n    self._saved_model_exporter = _SavedModelExporter(name,\n                                                     serving_input_receiver_fn,\n                                                     assets_extra, as_text)\n    self._exports_to_keep = exports_to_keep\n    if exports_to_keep is not None and exports_to_keep <= 0:\n      raise ValueError(\n          '`exports_to_keep`, if provided, must be positive number')\n\n  @property\n  def name(self):\n    return self._saved_model_exporter.name\n\n  def export(self, estimator, export_path, checkpoint_path, eval_result,\n             is_the_final_export):\n    export_result = self._saved_model_exporter.export(estimator, export_path,\n                                                      checkpoint_path,\n                                                      eval_result,\n                                                      is_the_final_export)\n\n    self._garbage_collect_exports(export_path)\n    return export_result\n\n  def _garbage_collect_exports(self, export_dir_base):\n    \"\"\"Deletes older exports, retaining only a given number of the most recent.\n\n    Export subdirectories are assumed to be named with monotonically increasing\n    integers; the most recent are taken to be those with the largest values.\n\n    Args:\n      export_dir_base: the base directory under which each export is in a\n        versioned subdirectory.\n    \"\"\"\n    if self._exports_to_keep is None:\n      return\n\n    def _export_version_parser(path):\n      # create a simple parser that pulls the export_version from the directory.\n      filename = os.path.basename(path.path)\n      if not (len(filename) == 10 and filename.isdigit()):\n        return None\n      return path._replace(export_version=int(filename))\n\n    # pylint: disable=protected-access\n    keep_filter = 
gc._largest_export_versions(self._exports_to_keep)\n    delete_filter = gc._negation(keep_filter)\n    for p in delete_filter(\n        gc._get_paths(export_dir_base, parser=_export_version_parser)):\n      try:\n        tf.compat.v1.gfile.DeleteRecursively(p.path)\n      except tf.errors.NotFoundError as e:\n        tf.compat.v1.logging.warn('Can not delete %s recursively: %s', p.path,\n                                  e)\n    # pylint: enable=protected-access\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/exporter_test.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for `Exporter`s.\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport tempfile\nimport time\nimport tensorflow as tf\nfrom tensorflow.python.eager import context\nfrom tensorflow.python.framework import test_util\nfrom tensorflow.python.platform import gfile\nfrom tensorflow_estimator.python.estimator import estimator as estimator_lib\nfrom tensorflow_estimator.python.estimator import exporter as exporter_lib\n\n\nclass BestExporterTest(tf.test.TestCase):\n\n  def test_error_out_if_exports_to_keep_is_zero(self):\n\n    def _serving_input_receiver_fn():\n      pass\n\n    with self.assertRaisesRegexp(ValueError, \"positive number\"):\n      exporter = exporter_lib.BestExporter(\n          name=\"best_exporter\",\n          serving_input_receiver_fn=_serving_input_receiver_fn,\n          exports_to_keep=0)\n      self.assertEqual(\"best_exporter\", exporter.name)\n\n  def test_best_exporter(self):\n\n    def _serving_input_receiver_fn():\n      pass\n\n    export_dir_base = tempfile.mkdtemp()\n    tf.compat.v1.gfile.MkDir(export_dir_base)\n    tf.compat.v1.gfile.MkDir(export_dir_base + \"/export\")\n    tf.compat.v1.gfile.MkDir(export_dir_base + \"/eval\")\n\n   
 exporter = exporter_lib.BestExporter(\n        name=\"best_exporter\",\n        serving_input_receiver_fn=_serving_input_receiver_fn,\n        assets_extra={\"from/path\": \"to/path\"},\n        as_text=False,\n        exports_to_keep=5)\n    estimator = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    estimator.export_saved_model.return_value = \"export_result_path\"\n    estimator.model_dir = export_dir_base\n\n    export_result = exporter.export(estimator, export_dir_base,\n                                    \"checkpoint_path\", {}, False)\n\n    self.assertEqual(\"export_result_path\", export_result)\n    estimator.export_saved_model.assert_called_with(\n        export_dir_base,\n        _serving_input_receiver_fn,\n        assets_extra={\"from/path\": \"to/path\"},\n        as_text=False,\n        checkpoint_path=\"checkpoint_path\")\n\n  def test_best_export_is_saved(self):\n\n    def _serving_input_receiver_fn():\n      pass\n\n    export_dir_base = tempfile.mkdtemp()\n    tf.compat.v1.gfile.MkDir(export_dir_base)\n    tf.compat.v1.gfile.MkDir(export_dir_base + \"/export\")\n    tf.compat.v1.gfile.MkDir(export_dir_base + \"/eval\")\n\n    exporter = exporter_lib.BestExporter(\n        name=\"best_exporter\",\n        serving_input_receiver_fn=_serving_input_receiver_fn,\n        assets_extra={\"from/path\": \"to/path\"},\n        as_text=False,\n        exports_to_keep=1)\n    estimator = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    estimator.export_saved_model.return_value = \"export_result_path\"\n    estimator.model_dir = export_dir_base\n\n    export_result = exporter.export(estimator, export_dir_base,\n                                    \"checkpoint_path\", {\"loss\": 0.5}, False)\n\n    self.assertTrue(estimator.export_saved_model.called)\n    self.assertEqual(\"export_result_path\", export_result)\n\n    export_result = exporter.export(estimator, export_dir_base,\n                                    
\"checkpoint_path\", {\"loss\": 0.6}, False)\n    self.assertEqual(None, export_result)\n\n    export_result = exporter.export(estimator, export_dir_base,\n                                    \"checkpoint_path\", {\"loss\": 0.4}, False)\n    self.assertEqual(\"export_result_path\", export_result)\n\n  def test_best_exporter_with_preemption(self):\n\n    def _serving_input_receiver_fn():\n      pass\n\n    export_dir_base = tempfile.mkdtemp()\n    tf.compat.v1.gfile.MkDir(export_dir_base)\n    tf.compat.v1.gfile.MkDir(export_dir_base + \"/export\")\n    tf.compat.v1.gfile.MkDir(export_dir_base + \"/eval\")\n\n    eval_dir_base = os.path.join(export_dir_base, \"eval_continuous\")\n    # _write_dict_to_summary is only called internally within graph mode.\n    with context.graph_mode():\n      estimator_lib._write_dict_to_summary(eval_dir_base, {\"loss\": 50}, 1)\n      estimator_lib._write_dict_to_summary(eval_dir_base, {\"loss\": 60}, 2)\n\n    exporter = exporter_lib.BestExporter(\n        name=\"best_exporter\",\n        serving_input_receiver_fn=_serving_input_receiver_fn,\n        event_file_pattern=\"eval_continuous/*.tfevents.*\",\n        assets_extra={\"from/path\": \"to/path\"},\n        as_text=False,\n        exports_to_keep=1)\n\n    estimator = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    estimator.model_dir = export_dir_base\n    estimator.export_saved_model.return_value = \"export_result_path\"\n\n    export_result = exporter.export(estimator, export_dir_base,\n                                    \"checkpoint_path\", {\"loss\": 100}, False)\n    self.assertEqual(\"export_result_path\", export_result)\n\n    export_result = exporter.export(estimator, export_dir_base,\n                                    \"checkpoint_path\", {\"loss\": 10}, False)\n    self.assertEqual(\"export_result_path\", export_result)\n\n    export_result = exporter.export(estimator, export_dir_base,\n                                    \"checkpoint_path\", 
{\"loss\": 20}, False)\n    self.assertEqual(None, export_result)\n\n  @test_util.run_v1_only(\"Tests v1 only symbols\")\n  def test_best_exporter_with_empty_event(self):\n\n    def _serving_input_receiver_fn():\n      pass\n\n    export_dir_base = tempfile.mkdtemp()\n    tf.compat.v1.gfile.MkDir(export_dir_base)\n    tf.compat.v1.gfile.MkDir(export_dir_base + \"/export\")\n    tf.compat.v1.gfile.MkDir(export_dir_base + \"/eval\")\n\n    eval_dir_base = os.path.join(export_dir_base, \"eval_continuous\")\n    estimator_lib._write_dict_to_summary(eval_dir_base, {}, 1)\n    estimator_lib._write_dict_to_summary(eval_dir_base, {\"loss\": 60}, 2)\n\n    exporter = exporter_lib.BestExporter(\n        name=\"best_exporter\",\n        serving_input_receiver_fn=_serving_input_receiver_fn,\n        event_file_pattern=\"eval_continuous/*.tfevents.*\",\n        assets_extra={\"from/path\": \"to/path\"},\n        as_text=False,\n        exports_to_keep=1)\n\n    estimator = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    estimator.model_dir = export_dir_base\n    estimator.export_saved_model.return_value = \"export_result_path\"\n\n    export_result = exporter.export(estimator, export_dir_base,\n                                    \"checkpoint_path\", {\"loss\": 100}, False)\n    self.assertEqual(\"export_result_path\", export_result)\n\n    export_result = exporter.export(estimator, export_dir_base,\n                                    \"checkpoint_path\", {\"loss\": 10}, False)\n    self.assertEqual(\"export_result_path\", export_result)\n\n  def test_the_first_export(self):\n\n    def _serving_input_receiver_fn():\n      pass\n\n    export_dir_base = tempfile.mkdtemp()\n    tf.compat.v1.gfile.MkDir(export_dir_base)\n    tf.compat.v1.gfile.MkDir(export_dir_base + \"/export\")\n    tf.compat.v1.gfile.MkDir(export_dir_base + \"/eval\")\n\n    exporter = exporter_lib.BestExporter(\n        name=\"best_exporter\",\n        
serving_input_receiver_fn=_serving_input_receiver_fn,\n        event_file_pattern=\"eval_continuous/*.tfevents.*\",\n        assets_extra={\"from/path\": \"to/path\"},\n        as_text=False,\n        exports_to_keep=1)\n\n    estimator = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    estimator.model_dir = export_dir_base\n    estimator.export_saved_model.return_value = \"export_result_path\"\n\n    # Note that evaluation occurs before export\n    with context.graph_mode():\n      eval_dir_base = os.path.join(export_dir_base, \"eval_continuous\")\n      first_evaluation_results = {\"loss\": 60}\n      estimator_lib._write_dict_to_summary(eval_dir_base,\n                                           first_evaluation_results, 1)\n\n    # export the model with the same results computed in the first evaluation\n    export_result = exporter.export(estimator, export_dir_base,\n                                    \"checkpoint_path\", first_evaluation_results,\n                                    False)\n    self.assertEqual(\"export_result_path\", export_result)\n\n  def test_garbage_collect_exports(self):\n    export_dir_base = tempfile.mkdtemp()\n    tf.compat.v1.gfile.MkDir(export_dir_base)\n    tf.compat.v1.gfile.MkDir(export_dir_base + \"/export\")\n    tf.compat.v1.gfile.MkDir(export_dir_base + \"/eval\")\n\n    export_dir_1 = _create_test_export_dir(export_dir_base)\n    export_dir_2 = _create_test_export_dir(export_dir_base)\n    export_dir_3 = _create_test_export_dir(export_dir_base)\n    export_dir_4 = _create_test_export_dir(export_dir_base)\n\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_1))\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_2))\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_3))\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_4))\n\n    def _serving_input_receiver_fn():\n      return tf.constant([1]), None\n\n    exporter = exporter_lib.BestExporter(\n        name=\"best_exporter\",\n        
serving_input_receiver_fn=_serving_input_receiver_fn,\n        exports_to_keep=2)\n    estimator = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    estimator.model_dir = export_dir_base\n    # Garbage collect all but the most recent 2 exports,\n    # where recency is determined based on the timestamp directory names.\n    exporter.export(estimator, export_dir_base, None, None, False)\n\n    self.assertFalse(tf.compat.v1.gfile.Exists(export_dir_1))\n    self.assertFalse(tf.compat.v1.gfile.Exists(export_dir_2))\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_3))\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_4))\n\n\nclass LatestExporterTest(tf.test.TestCase):\n\n  def test_error_out_if_exports_to_keep_is_zero(self):\n\n    def _serving_input_receiver_fn():\n      pass\n\n    with self.assertRaisesRegexp(ValueError, \"positive number\"):\n      exporter = exporter_lib.LatestExporter(\n          name=\"latest_exporter\",\n          serving_input_receiver_fn=_serving_input_receiver_fn,\n          exports_to_keep=0)\n      self.assertEqual(\"latest_exporter\", exporter.name)\n\n  def test_latest_exporter(self):\n\n    def _serving_input_receiver_fn():\n      pass\n\n    export_dir_base = tempfile.mkdtemp() + \"export/\"\n    tf.compat.v1.gfile.MkDir(export_dir_base)\n\n    exporter = exporter_lib.LatestExporter(\n        name=\"latest_exporter\",\n        serving_input_receiver_fn=_serving_input_receiver_fn,\n        assets_extra={\"from/path\": \"to/path\"},\n        as_text=False,\n        exports_to_keep=5)\n    estimator = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    estimator.export_saved_model.return_value = \"export_result_path\"\n\n    export_result = exporter.export(estimator, export_dir_base,\n                                    \"checkpoint_path\", {}, False)\n\n    self.assertEqual(\"export_result_path\", export_result)\n    estimator.export_saved_model.assert_called_with(\n        export_dir_base,\n       
 _serving_input_receiver_fn,\n        assets_extra={\"from/path\": \"to/path\"},\n        as_text=False,\n        checkpoint_path=\"checkpoint_path\")\n\n  def test_only_the_last_export_is_saved(self):\n\n    def _serving_input_receiver_fn():\n      pass\n\n    export_dir_base = tempfile.mkdtemp() + \"export/\"\n    tf.compat.v1.gfile.MkDir(export_dir_base)\n\n    exporter = exporter_lib.FinalExporter(\n        name=\"latest_exporter\",\n        serving_input_receiver_fn=_serving_input_receiver_fn,\n        assets_extra={\"from/path\": \"to/path\"},\n        as_text=False)\n    estimator = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    estimator.export_saved_model.return_value = \"export_result_path\"\n\n    export_result = exporter.export(estimator, export_dir_base,\n                                    \"checkpoint_path\", {}, False)\n\n    self.assertFalse(estimator.export_saved_model.called)\n    self.assertEqual(None, export_result)\n\n    export_result = exporter.export(estimator, export_dir_base,\n                                    \"checkpoint_path\", {}, True)\n\n    self.assertEqual(\"export_result_path\", export_result)\n    estimator.export_saved_model.assert_called_with(\n        export_dir_base,\n        _serving_input_receiver_fn,\n        assets_extra={\"from/path\": \"to/path\"},\n        as_text=False,\n        checkpoint_path=\"checkpoint_path\")\n\n  def test_garbage_collect_exports(self):\n    export_dir_base = tempfile.mkdtemp() + \"export/\"\n    tf.compat.v1.gfile.MkDir(export_dir_base)\n    export_dir_1 = _create_test_export_dir(export_dir_base)\n    export_dir_2 = _create_test_export_dir(export_dir_base)\n    export_dir_3 = _create_test_export_dir(export_dir_base)\n    export_dir_4 = _create_test_export_dir(export_dir_base)\n\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_1))\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_2))\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_3))\n    
self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_4))\n\n    def _serving_input_receiver_fn():\n      return tf.constant([1]), None\n\n    exporter = exporter_lib.LatestExporter(\n        name=\"latest_exporter\",\n        serving_input_receiver_fn=_serving_input_receiver_fn,\n        exports_to_keep=2)\n    estimator = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    # Garbage collect all but the most recent 2 exports,\n    # where recency is determined based on the timestamp directory names.\n    exporter.export(estimator, export_dir_base, None, None, False)\n\n    self.assertFalse(tf.compat.v1.gfile.Exists(export_dir_1))\n    self.assertFalse(tf.compat.v1.gfile.Exists(export_dir_2))\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_3))\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_4))\n\n  def test_garbage_collect_exports_with_trailing_delimiter(self):\n    export_dir_base = tempfile.mkdtemp() + \"export/\"\n    tf.compat.v1.gfile.MkDir(export_dir_base)\n    export_dir_1 = _create_test_export_dir(export_dir_base)\n    export_dir_2 = _create_test_export_dir(export_dir_base)\n    export_dir_3 = _create_test_export_dir(export_dir_base)\n    export_dir_4 = _create_test_export_dir(export_dir_base)\n\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_1))\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_2))\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_3))\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_4))\n\n    def _serving_input_receiver_fn():\n      return tf.constant([1]), None\n\n    exporter = exporter_lib.LatestExporter(\n        name=\"latest_exporter\",\n        serving_input_receiver_fn=_serving_input_receiver_fn,\n        exports_to_keep=1)\n    estimator = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    # Garbage collect all but the single most recent export,\n    # where recency is determined based on the timestamp directory names.\n    with 
tf.compat.v1.test.mock.patch.object(\n        gfile, \"ListDirectory\") as mock_list_directory:\n      mock_list_directory.return_value = [\n          os.path.basename(export_dir_1) + b\"/\",\n          os.path.basename(export_dir_2) + b\"/\",\n          os.path.basename(export_dir_3) + b\"/\",\n          os.path.basename(export_dir_4) + b\"/\",\n      ]\n      exporter.export(estimator, export_dir_base, None, None, False)\n\n    self.assertFalse(tf.compat.v1.gfile.Exists(export_dir_1))\n    self.assertFalse(tf.compat.v1.gfile.Exists(export_dir_2))\n    self.assertFalse(tf.compat.v1.gfile.Exists(export_dir_3))\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir_4))\n\n\ndef _create_test_export_dir(export_dir_base):\n  export_dir = _get_timestamped_export_dir(export_dir_base)\n  tf.compat.v1.gfile.MkDir(export_dir)\n  time.sleep(2)\n  return export_dir\n\n\ndef _get_timestamped_export_dir(export_dir_base):\n  # When we create a timestamped directory, there is a small chance that the\n  # directory already exists because another worker is also writing exports.\n  # In this case we just wait one second to get a new timestamp and try again.\n  # If this fails several times in a row, then something is seriously wrong.\n  max_directory_creation_attempts = 10\n\n  attempts = 0\n  while attempts < max_directory_creation_attempts:\n    export_timestamp = int(time.time())\n\n    export_dir = os.path.join(\n        tf.compat.as_bytes(export_dir_base),\n        tf.compat.as_bytes(str(export_timestamp)))\n    if not tf.compat.v1.gfile.Exists(export_dir):\n      # Collisions are still possible (though extremely unlikely): this\n      # directory is not actually created yet, but it will be almost\n      # instantly on return from this function.\n      return export_dir\n    time.sleep(1)\n    attempts += 1\n    tf.compat.v1.logging.warn(\n        \"Export directory {} already exists; retrying (attempt {}/{})\".format(\n            export_dir, attempts, 
max_directory_creation_attempts))\n  raise RuntimeError(\"Failed to obtain a unique export directory name after \"\n                     \"{} attempts.\".format(max_directory_creation_attempts))\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/extenders.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Extenders of tf.estimator.Estimator.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom tensorflow.python.util import function_utils\nfrom tensorflow_estimator.python.estimator import estimator as estimator_lib\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n_VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config'])\n\n\n@estimator_export('estimator.add_metrics')\ndef add_metrics(estimator, metric_fn):\n  \"\"\"Creates a new `tf.estimator.Estimator` which has given metrics.\n\n  Example:\n\n  ```python\n    def my_auc(labels, predictions):\n      auc_metric = tf_keras.metrics.AUC(name=\"my_auc\")\n      auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'])\n      return {'auc': auc_metric}\n\n    estimator = tf.estimator.DNNClassifier(...)\n    estimator = tf.estimator.add_metrics(estimator, my_auc)\n    estimator.train(...)\n    estimator.evaluate(...)\n  ```\n  Example usage of custom metric which uses features:\n\n  ```python\n    def my_auc(labels, predictions, features):\n      auc_metric = 
tf_keras.metrics.AUC(name=\"my_auc\")\n      auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'],\n                              sample_weight=features['weight'])\n      return {'auc': auc_metric}\n\n    estimator = tf.estimator.DNNClassifier(...)\n    estimator = tf.estimator.add_metrics(estimator, my_auc)\n    estimator.train(...)\n    estimator.evaluate(...)\n  ```\n\n  Args:\n    estimator: A `tf.estimator.Estimator` object.\n    metric_fn: A function which should obey the following signature:\n      - Args: can only have following four arguments in any order:\n        * predictions: Predictions `Tensor` or dict of `Tensor` created by given\n          `estimator`.\n        * features: Input `dict` of `Tensor` objects created by `input_fn` which\n          is given to `estimator.evaluate` as an argument.\n        * labels:  Labels `Tensor` or dict of `Tensor` created by `input_fn`\n          which is given to `estimator.evaluate` as an argument.\n        * config: config attribute of the `estimator`.\n       - Returns: Dict of metric results keyed by name. Final metrics are a\n         union of this and `estimator's` existing metrics. If there is a name\n         conflict between this and `estimator`s existing metrics, this will\n         override the existing one. 
The values of the dict are the results of\n         calling a metric function, namely a `(metric_tensor, update_op)` tuple.\n\n  Returns:\n      A new `tf.estimator.Estimator` which has a union of original metrics with\n        given ones.\n  \"\"\"\n  _verify_metric_fn_args(metric_fn)\n\n  def new_model_fn(features, labels, mode, config):\n    spec = estimator.model_fn(features, labels, mode, config)\n    if mode != ModeKeys.EVAL:\n      return spec\n    new_metrics = _call_metric_fn(metric_fn, features, labels, spec.predictions,\n                                  config)\n    all_metrics = spec.eval_metric_ops or {}\n    all_metrics.update(new_metrics)\n    return spec._replace(eval_metric_ops=all_metrics)\n\n  return estimator_lib.Estimator(\n      model_fn=new_model_fn,\n      model_dir=estimator.model_dir,\n      config=estimator.config,\n      # pylint: disable=protected-access\n      warm_start_from=estimator._warm_start_settings)\n  # pylint: enable=protected-access\n\n\ndef _verify_metric_fn_args(metric_fn):\n  args = set(function_utils.fn_args(metric_fn))\n  invalid_args = list(args - _VALID_METRIC_FN_ARGS)\n  if invalid_args:\n    raise ValueError('metric_fn (%s) has following not expected args: %s' %\n                     (metric_fn, invalid_args))\n\n\ndef _call_metric_fn(metric_fn, features, labels, predictions, config):\n  \"\"\"Calls metric fn with proper arguments.\"\"\"\n  metric_fn_args = function_utils.fn_args(metric_fn)\n  kwargs = {}\n  if 'features' in metric_fn_args:\n    kwargs['features'] = features\n  if 'labels' in metric_fn_args:\n    kwargs['labels'] = labels\n  if 'predictions' in metric_fn_args:\n    kwargs['predictions'] = predictions\n  if 'config' in metric_fn_args:\n    kwargs['config'] = config\n  return metric_fn(**kwargs)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/extenders_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"extenders tests.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator import extenders\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator import run_config\nfrom tensorflow_estimator.python.estimator.canned import linear\n\n\ndef get_input_fn(x, y):\n\n  def input_fn():\n    dataset = tf.compat.v1.data.Dataset.from_tensor_slices({'x': x, 'y': y})\n    iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)\n    features = iterator.get_next()\n    labels = features.pop('y')\n    return features, labels\n\n  return input_fn\n\n\nclass AddMetricsTest(tf.test.TestCase):\n\n  def test_should_add_metrics(self):\n\n    def _test_metric_fn(metric_fn):\n      input_fn = get_input_fn(\n          x=np.arange(4)[:, None, None], y=np.ones(4)[:, None])\n      config = run_config.RunConfig(log_step_count_steps=1)\n      estimator = linear.LinearClassifierV2(\n          [tf.feature_column.numeric_column('x')], config=config)\n\n      estimator = extenders.add_metrics(estimator, metric_fn)\n\n      estimator.train(input_fn=input_fn)\n      metrics = 
estimator.evaluate(input_fn=input_fn)\n      self.assertIn('mean_x', metrics)\n      self.assertEqual(1.5, metrics['mean_x'])\n      # assert that it keeps original estimators metrics\n      self.assertIn('auc', metrics)\n\n    def metric_fn(features):\n      metric = tf_keras.metrics.Mean()\n      metric.update_state(features['x'])\n      return {'mean_x': metric}\n\n    _test_metric_fn(metric_fn)\n\n  def test_should_error_out_for_not_recognized_args(self):\n    estimator = linear.LinearClassifierV2(\n        [tf.feature_column.numeric_column('x')])\n\n    def metric_fn(features, not_recognized):\n      _, _ = features, not_recognized\n      return {}\n\n    with self.assertRaisesRegexp(ValueError, 'not_recognized'):\n      estimator = extenders.add_metrics(estimator, metric_fn)\n\n  def test_all_supported_args(self):\n    input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]])\n    estimator = linear.LinearClassifierV2(\n        [tf.feature_column.numeric_column('x')])\n\n    def metric_fn(features, predictions, labels, config):\n      self.assertIn('x', features)\n      self.assertIsNotNone(labels)\n      self.assertIn('logistic', predictions)\n      self.assertTrue(isinstance(config, run_config.RunConfig))\n      return {}\n\n    estimator = extenders.add_metrics(estimator, metric_fn)\n\n    estimator.train(input_fn=input_fn)\n    estimator.evaluate(input_fn=input_fn)\n\n  def test_all_supported_args_in_different_order(self):\n    input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]])\n    estimator = linear.LinearClassifierV2(\n        [tf.feature_column.numeric_column('x')])\n\n    def metric_fn(labels, config, features, predictions):\n      self.assertIn('x', features)\n      self.assertIsNotNone(labels)\n      self.assertIn('logistic', predictions)\n      self.assertTrue(isinstance(config, run_config.RunConfig))\n      return {}\n\n    estimator = extenders.add_metrics(estimator, metric_fn)\n\n    estimator.train(input_fn=input_fn)\n    
estimator.evaluate(input_fn=input_fn)\n\n  def test_all_args_are_optional(self):\n\n    def _test_metric_fn(metric_fn):\n      input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]])\n      estimator = linear.LinearClassifierV2(\n          [tf.feature_column.numeric_column('x')])\n      estimator = extenders.add_metrics(estimator, metric_fn)\n\n      estimator.train(input_fn=input_fn)\n      metrics = estimator.evaluate(input_fn=input_fn)\n      self.assertEqual(2., metrics['two'])\n\n    def metric_fn():\n      metric = tf_keras.metrics.Mean()\n      metric.update_state(tf.constant([2.]))\n      return {'two': metric}\n\n    _test_metric_fn(metric_fn)\n\n  def test_overrides_existing_metrics(self):\n\n    def _test_metric_fn(metric_fn):\n      input_fn = get_input_fn(x=[[[0.]]], y=[[[1]]])\n      estimator = linear.LinearClassifierV2(\n          [tf.feature_column.numeric_column('x')])\n      estimator.train(input_fn=input_fn)\n      metrics = estimator.evaluate(input_fn=input_fn)\n      self.assertNotEqual(2., metrics['auc'])\n\n      estimator = extenders.add_metrics(estimator, metric_fn)\n      metrics = estimator.evaluate(input_fn=input_fn)\n      self.assertEqual(2., metrics['auc'])\n\n    def metric_fn():\n      metric = tf_keras.metrics.Mean()\n      metric.update_state(tf.constant([2.]))\n      return {'auc': metric}\n\n    _test_metric_fn(metric_fn)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/gc.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\nr\"\"\"System for specifying garbage collection (GC) of path based data.\n\nThis framework allows for GC of data specified by path names, for example files\non disk.  gc.Path objects each represent a single item stored at a path and may\nbe a base directory,\n  /tmp/exports/0/...\n  /tmp/exports/1/...\n  ...\nor a fully qualified file,\n  /tmp/train-1.ckpt\n  /tmp/train-2.ckpt\n  ...\n\nA gc filter function takes and returns a list of gc.Path items.  
Filter\nfunctions are responsible for selecting Path items for preservation or deletion.\nNote that functions should always return a sorted list.\n\nFor example,\n  base_dir = \"/tmp\"\n  # Create the directories.\n  for e in xrange(10):\n    os.mkdir(\"%s/%d\" % (base_dir, e), 0o755)\n\n  # Create a simple parser that pulls the export_version from the directory.\n  path_regex = \"^\" + re.escape(base_dir) + \"/(\\\\d+)$\"\n  def parser(path):\n    match = re.match(path_regex, path.path)\n    if not match:\n      return None\n    return path._replace(export_version=int(match.group(1)))\n\n  path_list = gc._get_paths(\"/tmp\", parser)  # contains all ten Paths\n\n  every_fifth = gc._mod_export_version(5)\n  print(every_fifth(path_list))  # shows [\"/tmp/0\", \"/tmp/5\"]\n\n  largest_three = gc._largest_export_versions(3)\n  print(largest_three(path_list))  # shows [\"/tmp/7\", \"/tmp/8\", \"/tmp/9\"]\n\n  both = gc._union(every_fifth, largest_three)\n  print(both(path_list))  # shows [\"/tmp/0\", \"/tmp/5\",\n                          #        \"/tmp/7\", \"/tmp/8\", \"/tmp/9\"]\n  # Delete everything not in 'both'.\n  to_delete = gc._negation(both)\n  for p in to_delete(path_list):\n    gfile.DeleteRecursively(p.path)  # deletes:  \"/tmp/1\", \"/tmp/2\",\n                                     # \"/tmp/3\", \"/tmp/4\", \"/tmp/6\",\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\nimport collections\nimport heapq\nimport math\nimport os\nimport tensorflow as tf\nfrom tensorflow.python.platform import gfile\n\nPath = collections.namedtuple('Path', 'path export_version')\n\n\ndef _largest_export_versions(n):\n  \"\"\"Creates a filter that keeps the largest n export versions.\n\n  Args:\n    n: number of versions to keep.\n\n  Returns:\n    A filter function that keeps the n largest paths.\n  \"\"\"\n\n  def keep(paths):\n    heap = []\n    for idx, path in enumerate(paths):\n      if 
path.export_version is not None:\n        heapq.heappush(heap, (path.export_version, idx))\n    keepers = [paths[i] for _, i in heapq.nlargest(n, heap)]\n    return sorted(keepers)\n\n  return keep\n\n\ndef _one_of_every_n_export_versions(n):\n  \"\"\"Creates a filter that keeps one of every n export versions.\n\n  Args:\n    n: interval size.\n\n  Returns:\n    A filter function that keeps exactly one path from each interval\n    [0, n], (n, 2n], (2n, 3n], etc...  If more than one path exists in an\n    interval the largest is kept.\n  \"\"\"\n\n  def keep(paths):\n    \"\"\"A filter function that keeps exactly one out of every n paths.\"\"\"\n\n    keeper_map = {}  # map from interval to largest path seen in that interval\n    for p in paths:\n      if p.export_version is None:\n        # Skip missing export_versions.\n        continue\n      # Find the interval (with a special case to map export_version = 0 to\n      # interval 0.\n      interval = math.floor(\n          (p.export_version - 1) / n) if p.export_version else 0\n      existing = keeper_map.get(interval, None)\n      if (not existing) or (existing.export_version < p.export_version):\n        keeper_map[interval] = p\n    return sorted(keeper_map.values())\n\n  return keep\n\n\ndef _mod_export_version(n):\n  \"\"\"Creates a filter that keeps every export that is a multiple of n.\n\n  Args:\n    n: step size.\n\n  Returns:\n    A filter function that keeps paths where export_version % n == 0.\n  \"\"\"\n\n  def keep(paths):\n    keepers = []\n    for p in paths:\n      if p.export_version % n == 0:\n        keepers.append(p)\n    return sorted(keepers)\n\n  return keep\n\n\ndef _union(lf, rf):\n  \"\"\"Creates a filter that keeps the union of two filters.\n\n  Args:\n    lf: first filter\n    rf: second filter\n\n  Returns:\n    A filter function that keeps the n largest paths.\n  \"\"\"\n\n  def keep(paths):\n    l = set(lf(paths))\n    r = set(rf(paths))\n    return sorted(list(l | r))\n\n  return 
keep\n\n\ndef _negation(f):\n  \"\"\"Negate a filter.\n\n  Args:\n    f: filter function to invert\n\n  Returns:\n    A filter function that returns the negation of f.\n  \"\"\"\n\n  def keep(paths):\n    l = set(paths)\n    r = set(f(paths))\n    return sorted(list(l - r))\n\n  return keep\n\n\ndef _get_paths(base_dir, parser):\n  \"\"\"Gets a list of Paths in a given directory.\n\n  Args:\n    base_dir: directory.\n    parser: a function which gets the raw Path and can augment it with\n      information such as the export_version, or ignore the path by returning\n      None.  An example parser may extract the export version from a path such\n      as \"/tmp/exports/100\" and another may extract it from a full file name\n      such as \"/tmp/checkpoint-99.out\".\n\n  Returns:\n    A list of Paths contained in the base directory with the parsing function\n    applied.\n    By default the following fields are populated,\n      - Path.path\n    The parsing function is responsible for populating,\n      - Path.export_version\n  \"\"\"\n  # We are mocking this in the test, hence we should not use public API\n  raw_paths = gfile.ListDirectory(base_dir)\n  paths = []\n  for r in raw_paths:\n    # ListDirectory() returns names with a trailing \"/\" if base_dir is a GCS URL\n    r = tf.compat.as_str_any(r)\n    if r[-1] == '/':\n      r = r[0:len(r) - 1]\n    p = parser(Path(os.path.join(tf.compat.as_str_any(base_dir), r), None))\n    if p:\n      paths.append(p)\n  return sorted(paths)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/gc_test.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for garbage collection utilities.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport re\n\nfrom six.moves import xrange  # pylint: disable=redefined-builtin\nimport tensorflow as tf\nfrom tensorflow.python.platform import gfile\nfrom tensorflow_estimator.python.estimator import gc\n\n\ndef _create_parser(base_dir):\n  # create a simple parser that pulls the export_version from the directory.\n  def parser(path):\n    # Modify the path object for RegEx match for Windows Paths\n    if os.name == \"nt\":\n      match = re.match(\n          \"^\" + tf.compat.as_str_any(base_dir).replace(\"\\\\\", \"/\") + \"/(\\\\d+)$\",\n          tf.compat.as_str_any(path.path).replace(\"\\\\\", \"/\"))\n    else:\n      match = re.match(\"^\" + tf.compat.as_str_any(base_dir) + \"/(\\\\d+)$\",\n                       tf.compat.as_str_any(path.path))\n    if not match:\n      return None\n    return path._replace(export_version=int(match.group(1)))\n\n  return parser\n\n\nclass GcTest(tf.test.TestCase):\n\n  def testLargestExportVersions(self):\n    paths = [gc.Path(\"/foo\", 8), gc.Path(\"/foo\", 9), gc.Path(\"/foo\", 10)]\n    newest = gc._largest_export_versions(2)\n    n 
= newest(paths)\n    self.assertEqual(n, [gc.Path(\"/foo\", 9), gc.Path(\"/foo\", 10)])\n\n  def testLargestExportVersionsDoesNotDeleteZeroFolder(self):\n    paths = [gc.Path(\"/foo\", 0), gc.Path(\"/foo\", 3)]\n    newest = gc._largest_export_versions(2)\n    n = newest(paths)\n    self.assertEqual(n, [gc.Path(\"/foo\", 0), gc.Path(\"/foo\", 3)])\n\n  def testModExportVersion(self):\n    paths = [\n        gc.Path(\"/foo\", 4),\n        gc.Path(\"/foo\", 5),\n        gc.Path(\"/foo\", 6),\n        gc.Path(\"/foo\", 9)\n    ]\n    mod = gc._mod_export_version(2)\n    self.assertEqual(mod(paths), [gc.Path(\"/foo\", 4), gc.Path(\"/foo\", 6)])\n    mod = gc._mod_export_version(3)\n    self.assertEqual(mod(paths), [gc.Path(\"/foo\", 6), gc.Path(\"/foo\", 9)])\n\n  def testOneOfEveryNExportVersions(self):\n    paths = [\n        gc.Path(\"/foo\", 0),\n        gc.Path(\"/foo\", 1),\n        gc.Path(\"/foo\", 3),\n        gc.Path(\"/foo\", 5),\n        gc.Path(\"/foo\", 6),\n        gc.Path(\"/foo\", 7),\n        gc.Path(\"/foo\", 8),\n        gc.Path(\"/foo\", 33)\n    ]\n    one_of = gc._one_of_every_n_export_versions(3)\n    self.assertEqual(\n        one_of(paths), [\n            gc.Path(\"/foo\", 3),\n            gc.Path(\"/foo\", 6),\n            gc.Path(\"/foo\", 8),\n            gc.Path(\"/foo\", 33)\n        ])\n\n  def testOneOfEveryNExportVersionsZero(self):\n    # Zero is a special case since it gets rolled into the first interval.\n    # Test that here.\n    paths = [gc.Path(\"/foo\", 0), gc.Path(\"/foo\", 4), gc.Path(\"/foo\", 5)]\n    one_of = gc._one_of_every_n_export_versions(3)\n    self.assertEqual(one_of(paths), [gc.Path(\"/foo\", 0), gc.Path(\"/foo\", 5)])\n\n  def testUnion(self):\n    paths = []\n    for i in xrange(10):\n      paths.append(gc.Path(\"/foo\", i))\n    f = gc._union(gc._largest_export_versions(3), gc._mod_export_version(3))\n    self.assertEqual(\n        f(paths), [\n            gc.Path(\"/foo\", 0),\n            gc.Path(\"/foo\", 
3),\n            gc.Path(\"/foo\", 6),\n            gc.Path(\"/foo\", 7),\n            gc.Path(\"/foo\", 8),\n            gc.Path(\"/foo\", 9)\n        ])\n\n  def testNegation(self):\n    paths = [\n        gc.Path(\"/foo\", 4),\n        gc.Path(\"/foo\", 5),\n        gc.Path(\"/foo\", 6),\n        gc.Path(\"/foo\", 9)\n    ]\n    mod = gc._negation(gc._mod_export_version(2))\n    self.assertEqual(mod(paths), [gc.Path(\"/foo\", 5), gc.Path(\"/foo\", 9)])\n    mod = gc._negation(gc._mod_export_version(3))\n    self.assertEqual(mod(paths), [gc.Path(\"/foo\", 4), gc.Path(\"/foo\", 5)])\n\n  def testPathsWithParse(self):\n    base_dir = os.path.join(tf.compat.v1.test.get_temp_dir(), \"paths_parse\")\n    self.assertFalse(tf.compat.v1.gfile.Exists(base_dir))\n    for p in xrange(3):\n      tf.compat.v1.gfile.MakeDirs(os.path.join(base_dir, \"%d\" % p))\n    # add a base_directory to ignore\n    tf.compat.v1.gfile.MakeDirs(os.path.join(base_dir, \"ignore\"))\n\n    self.assertEqual(\n        gc._get_paths(base_dir, _create_parser(base_dir)), [\n            gc.Path(os.path.join(base_dir, \"0\"), 0),\n            gc.Path(os.path.join(base_dir, \"1\"), 1),\n            gc.Path(os.path.join(base_dir, \"2\"), 2)\n        ])\n    tf.compat.v1.gfile.DeleteRecursively(base_dir)\n\n  def testMixedStrTypes(self):\n    temp_dir = tf.compat.as_bytes(tf.compat.v1.test.get_temp_dir())\n\n    for sub_dir in [\"str\", b\"bytes\", u\"unicode\"]:\n      base_dir = os.path.join(\n          (temp_dir if isinstance(sub_dir, bytes) else temp_dir.decode()),\n          sub_dir)\n      self.assertFalse(tf.compat.v1.gfile.Exists(base_dir))\n      tf.compat.v1.gfile.MakeDirs(\n          os.path.join(tf.compat.as_str_any(base_dir), \"42\"))\n      gc._get_paths(base_dir, _create_parser(base_dir))\n      tf.compat.v1.gfile.DeleteRecursively(base_dir)\n\n  def testGcsDirWithSeparator(self):\n    base_dir = \"gs://bucket/foo\"\n    with tf.compat.v1.test.mock.patch.object(\n        gfile, 
\"ListDirectory\") as mock_list_directory:\n      # gfile.ListDirectory returns directory names with separator '/'\n      mock_list_directory.return_value = [\"0/\", \"1/\"]\n      self.assertEqual(\n          gc._get_paths(base_dir, _create_parser(base_dir)), [\n              gc.Path(os.path.join(base_dir, \"0\"), 0),\n              gc.Path(os.path.join(base_dir, \"1\"), 1)\n          ])\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/__init__.py",
    "content": ""
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/base_head.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Abstractions for the base head class.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport abc\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.feature_column import feature_column_lib\nfrom tensorflow.python.feature_column.feature_column import _LazyBuilder\nfrom tensorflow.python.feature_column.feature_column import _NumericColumn\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.util import function_utils\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.export import export_output\n\nDEFAULT_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n\n# The above default is defined by TF Serving, but these next three are just\n# a local convention without any special meaning.\nCLASSIFY_SERVING_KEY = 'classification'\nREGRESS_SERVING_KEY = 'regression'\nPREDICT_SERVING_KEY = 'predict'\n\n\n@estimator_export('estimator.Head')\n@six.add_metaclass(abc.ABCMeta)\nclass Head(object):\n  \"\"\"Interface for the head/top of a 
model.\n\n  Head sits on top of the model network and handles computing the outputs of\n  the network. Given logits (or output of a hidden layer), a Head knows how to\n  compute predictions, loss, train_op, metrics and export outputs. It is meant\n  to:\n\n  1. Simplify writing model_fn and to make model_fn more configurable for\n     Estimator.\n  2. Simplify creating loss and metrics for the train and test loop in Eager\n     execution.\n  3. Support a wide range of machine learning models. Since most heads can work\n     with logits, they can support DNN, RNN, Wide, Wide&Deep,\n     Global objectives, Gradient boosted trees and many other types\n     of machine learning models.\n\n  Common usage:\n  Here is a simplified model_fn to build a DNN regression model.\n    ```python\n    def _my_dnn_model_fn(features, labels, mode, params, config=None):\n      # Optionally your callers can pass head to model_fn as a param.\n      head = tf.estimator.RegressionHead(...)\n\n      feature_columns = tf.feature_column.numeric_column(...)\n      feature_layer = tf_keras.layers.DenseFeatures(feature_columns)\n      inputs = feature_layer(features)\n\n      # Compute logits with tf_keras.layers API\n      hidden_layer0 = tf_keras.layers.Dense(\n          units=1000, activation=\"relu\")(inputs)\n      hidden_layer1 = tf_keras.layers.Dense(\n          units=500, activation=\"relu\")(hidden_layer0)\n      logits = tf_keras.layers.Dense(\n          units=head.logits_dimension, activation=None)(hidden_layer1)\n\n      # Or use Keras model for logits computation\n      model = tf_keras.Sequential()\n      model.add(tf_keras.layers.Dense(units=1000, activation=\"relu\"))\n      model.add(tf_keras.layers.Dense(units=500, activation=\"relu\"))\n      model.add(tf_keras.layers.Dense(\n         units=head.logits_dimension, activation=None))\n      logits = model(inputs)\n\n      return head.create_estimator_spec(\n          features=features,\n          labels=labels,\n          
mode=mode,\n          logits=logits,\n          optimizer=optimizer)\n    ```\n  \"\"\"\n\n  @abc.abstractproperty\n  def name(self):\n    \"\"\"The name of this head.\n\n    Returns:\n      A string.\n    \"\"\"\n    raise NotImplementedError('Calling an abstract method.')\n\n  @abc.abstractproperty\n  def logits_dimension(self):\n    \"\"\"Size of the last dimension of the logits `Tensor`.\n\n    Often is the number of classes, labels, or real values to be predicted.\n    Typically, logits is of shape `[batch_size, logits_dimension]`.\n\n    Returns:\n      The expected size of the `logits` tensor.\n    \"\"\"\n    raise NotImplementedError('Calling an abstract method.')\n\n  @abc.abstractproperty\n  def loss_reduction(self):\n    \"\"\"One of `tf.losses.Reduction`.\n\n    Describes how to reduce training loss over batch, such as mean or sum.\n\n    Returns:\n      The type of loss reduction used in the head.\n    \"\"\"\n    raise NotImplementedError('Calling an abstract method.')\n\n  @abc.abstractmethod\n  def loss(self,\n           labels,\n           logits,\n           features=None,\n           mode=None,\n           regularization_losses=None):\n    \"\"\"Returns a loss `Tensor` from provided arguments.\n\n    Note that, the args of `features` and `mode` are most likely not used, but\n    some Head implementations may require them.\n\n    Args:\n      labels: Labels `Tensor`, or `dict` mapping string label names to `Tensor`\n        objects of the label values.\n      logits: Logits `Tensor` to be used for loss construction.\n      features: Input `dict` mapping string feature names to `Tensor` or\n        `SparseTensor` objects containing the values for that feature in a\n        minibatch. Often to be used to fetch example-weight tensor.\n      mode: Estimator's `ModeKeys`. 
To be used in case loss calculation is\n        different in Train and Eval mode.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses.\n\n    Returns:\n      A scalar `Tensor` representing regularized training loss used in train and\n      eval.\n    \"\"\"\n    raise NotImplementedError('Calling an abstract method.')\n\n  @abc.abstractmethod\n  def predictions(self, logits, keys=None):\n    \"\"\"Returns a `dict` of predictions from provided logits.\n\n    Args:\n      logits: Logits `Tensor` to be used for prediction construction.\n      keys: A list of `string` for prediction keys. Defaults to `None`, meaning\n        if not specified, predictions will be created for all the pre-defined\n        valid keys in the head.\n\n    Returns:\n      A `dict` of predicted `Tensor` keyed by prediction name.\n    \"\"\"\n    raise NotImplementedError('Calling an abstract method.')\n\n  @abc.abstractmethod\n  def metrics(self, regularization_losses=None):\n    \"\"\"Returns a `dict` of metric objects.\n\n    Args:\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses.\n\n    Returns:\n       A `dict` of metrics keyed by string name. 
The value is an instance of\n       `Metric` class.\n    \"\"\"\n    raise NotImplementedError('Calling an abstract method.')\n\n  @abc.abstractmethod\n  def update_metrics(self,\n                     eval_metrics,\n                     features,\n                     logits,\n                     labels,\n                     mode=None,\n                     regularization_losses=None):\n    \"\"\"Updates metric objects and returns a `dict` of the updated metrics.\n\n    Args:\n      eval_metrics: A `dict` of metrics to be updated.\n      features: Input `dict` mapping string feature names to `Tensor` or\n        `SparseTensor` objects containing the values for that feature in a\n        minibatch. Often to be used to fetch example-weight tensor.\n      logits: logits `Tensor` to be used for metrics update.\n      labels: Labels `Tensor`, or `dict` mapping string label names to `Tensor`\n        objects of the label values.\n      mode: Estimator's `ModeKeys`. In most cases, this arg is not used and can\n        be removed in the method implementation.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training and evaluation loss, such as regularization losses.  Note\n        that, the `mode` arg is not used in the `tf.estimator.*Head`. If the\n        update of the metrics doesn't rely on `mode`, it can be safely ignored\n        in the method signature.\n\n    Returns:\n       A `dict` of updated metrics keyed by name. 
The value is an instance of\n       `Metric` class.\n    \"\"\"\n    raise NotImplementedError('Calling an abstract method.')\n\n  def _summary_key(self, key):\n    return '{}/{}'.format(key, self.name) if self.name else key\n\n  def create_estimator_spec(self,\n                            features,\n                            mode,\n                            logits,\n                            labels=None,\n                            optimizer=None,\n                            trainable_variables=None,\n                            train_op_fn=None,\n                            update_ops=None,\n                            regularization_losses=None):\n    \"\"\"Returns `EstimatorSpec` that a model_fn can return.\n\n    It is recommended to pass all args via name.\n\n    Args:\n      features: Input `dict` mapping string feature names to `Tensor` or\n        `SparseTensor` objects containing the values for that feature in a\n        minibatch. Often to be used to fetch example-weight tensor.\n      mode: Estimator's `ModeKeys`.\n      logits: Logits `Tensor` to be used by the head.\n      labels: Labels `Tensor`, or `dict` mapping string label names to `Tensor`\n        objects of the label values.\n      optimizer: An `tf_keras.optimizers.Optimizer` instance to optimize the\n        loss in TRAIN mode. Namely, sets `train_op = optimizer.get_updates(loss,\n        trainable_variables)`, which updates variables to minimize `loss`.\n      trainable_variables: A list or tuple of `Variable` objects to update to\n        minimize `loss`. In Tensorflow 1.x, by default these are the list of\n        variables collected in the graph under the key\n        `GraphKeys.TRAINABLE_VARIABLES`. As Tensorflow 2.x doesn't have\n        collections and GraphKeys, trainable_variables need to be passed\n        explicitly here.\n      train_op_fn: Function that takes a scalar loss `Tensor` and returns an op\n        to optimize the model with the loss in TRAIN mode. 
Used if `optimizer`\n        is `None`. Exactly one of `train_op_fn` and `optimizer` must be set in\n        TRAIN mode. By default, it is `None` in other modes. If you want to\n        optimize loss yourself, you can pass `lambda _: tf.no_op()` and then use\n          `EstimatorSpec.loss` to compute and apply gradients.\n      update_ops: A list or tuple of update ops to be run at training time. For\n        example, layers such as BatchNormalization create mean and variance\n        update ops that need to be run at training time. In Tensorflow 1.x,\n        these are thrown into an UPDATE_OPS collection. As Tensorflow 2.x\n        doesn't have collections, update_ops need to be passed explicitly here.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses.\n\n    Returns:\n      `EstimatorSpec`.\n    \"\"\"\n    # Not all subclasses of Head will have implemented\n    # _create_tpu_estimator_spec. If it is implemented, we can convert it to\n    # the normal `EstimatorSpec` by calling the method of\n    # `_TPUEstimatorSpec.as_estimator_spec()`.\n    try:\n      tpu_estimator_spec = (\n          self._create_tpu_estimator_spec(\n              features=features,\n              mode=mode,\n              logits=logits,\n              labels=labels,\n              optimizer=optimizer,\n              trainable_variables=trainable_variables,\n              train_op_fn=train_op_fn,\n              update_ops=update_ops,\n              regularization_losses=regularization_losses))\n      return tpu_estimator_spec.as_estimator_spec()\n    except NotImplementedError:\n      raise NotImplementedError(\n          'Subclasses of Head must implement `create_estimator_spec()` or '\n          '_create_tpu_estimator_spec().')\n\n  def _create_tpu_estimator_spec(\n      self,\n      features,\n      mode,\n      logits,\n      labels=None,\n      optimizer=None,\n      trainable_variables=None,\n      
train_op_fn=None,\n      update_ops=None,\n      regularization_losses=None,\n  ):\n    \"\"\"Returns `model_fn._TPUEstimatorSpec` that a model_fn can return.\n\n    Args:\n      features: Input `dict` mapping string feature names to `Tensor` or\n        `SparseTensor` objects containing the values for that feature in a\n        minibatch. Often to be used to fetch example-weight tensor.\n      mode: Estimator's `ModeKeys`.\n      logits: Logits `Tensor` to be used by the head.\n      labels: Labels `Tensor`, or `dict` mapping string label names to `Tensor`\n        objects of the label values.\n      optimizer: An `tf_keras.optimizers.Optimizer` instance to optimize the\n        loss in TRAIN mode. Namely, sets `train_op = optimizer.get_updates(loss,\n        trainable_variables)`, which updates variables to minimize `loss`.\n      trainable_variables: A list or tuple of `Variable` objects to update to\n        minimize `loss`. In Tensorflow 1.x, by default these are the list of\n        variables collected in the graph under the key\n        `GraphKeys.TRAINABLE_VARIABLES`. As Tensorflow 2.x doesn't have\n        collections and GraphKeys, trainable_variables need to be passed\n        explicitly here.\n      train_op_fn: Function that takes a scalar loss `Tensor` and returns an op\n        to optimize the model with the loss in TRAIN mode. Used if `optimizer`\n        is `None`. Exactly one of `train_op_fn` and `optimizer` must be set in\n        TRAIN mode. By default, it is `None` in other modes. If you want to\n        optimize loss yourself, you can pass `lambda _: tf.no_op()` and then use\n          `EstimatorSpec.loss` to compute and apply gradients.\n      update_ops: A list or tuple of update ops to be run at training time. For\n        example, layers such as BatchNormalization create mean and variance\n        update ops that need to be run at training time. In Tensorflow 1.x,\n        these are thrown into an UPDATE_OPS collection. 
As Tensorflow 2.x\n        doesn't have collections, update_ops need to be passed explicitly here.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses.\n\n    Returns:\n      A `model_fn._TPUEstimatorSpec' instance.\n    \"\"\"\n    raise NotImplementedError(\n        'TPUEstimatorSpec not available for this model head.')\n\n\n# TODO(b/119617064): unify eager and graph implementations\n# Note that, tensor shape checking is slow in Eager mode. To amend it, the\n# tensor static shape is used for checking. The duplication of shape checking\n# for eager mode in the following helper functions can be safely removed\n# if there's some way to get around it in the future.\n\n# Label shape error messages.\n_LABEL_NONE_ERR_MSG = (\n    'You must provide a labels Tensor. Given: None. '\n    'Suggested troubleshooting steps: Check that your data contains your label '\n    'feature. Check that your input_fn properly parses and returns labels.')\n\n_SPARSE_LABEL_ERR_MSG = (\n    'SparseTensor labels are not supported. Labels must be a Tensor of shape '\n    '[D0, D1, ..., DN, {}], e.g. [batch_size, {}].Suggested Fix (1): Check the'\n    ' label feature in your data. Each example must contain {} value(s). If '\n    'not, your choice of label was probably incorrect. Suggested Fix (2): In '\n    'your input_fn, use tf.sparse_tensor_to_dense() to turn labels into a '\n    'Tensor.')\n\n_MISMATCHED_LABEL_DIM_ERR_MSG = (\n    'Mismatched label shape. Expected labels dimension={}.  Received {}. '\n    'Suggested Fix: If your classifier expects one-hot encoding label, check '\n    'your n_classes argument to the estimator and/or the shape of your label. '\n    'Otherwise, check the shape of your label.')\n\n_LABEL_SHAPE_ERR_MSG = (\n    'labels shape must be [D0, D1, ... DN, {}]. 
Suggested Fix: check your '\n    'n_classes argument to the head and/or the shape of your label.')\n\n_VALIDATION_ERROR_MSG = '{} should be a list or a tuple. Given type: {}.'\n\n\ndef check_dense_labels_match_logits_and_reshape(labels, logits,\n                                                expected_labels_dimension):\n  \"\"\"Checks labels shape matches logits, and reshapes if needed.\n\n  Consider logits of shape [D0, D1, ... DN, logits_dimension]. Then labels\n  shape must be [D0, D1, ... DN, expected_labels_dimension].\n  If expected_labels_dimension=1, labels could be [D0, D1, ... DN] and this\n  method reshapes them to [D0, D1, ... DN, 1].\n\n  Args:\n    labels: labels Tensor.\n    logits: logits Tensor.\n    expected_labels_dimension: Integer.\n\n  Returns:\n    Validated and reshaped labels Tensor.\n\n  Raises:\n    ValueError: If labels is a SparseTensor.\n    ValueError: If labels shape is statically defined and fails validation.\n    OpError: If labels shape is not statically defined and fails validation.\n  \"\"\"\n  if labels is None:\n    raise ValueError(_LABEL_NONE_ERR_MSG)\n  with ops.name_scope('labels', values=(labels, logits)) as scope:\n    labels = tf.compat.v1.convert_to_tensor_or_sparse_tensor(labels)\n    if isinstance(labels, tf.sparse.SparseTensor):\n      raise ValueError(\n          _SPARSE_LABEL_ERR_MSG.format(expected_labels_dimension,\n                                       expected_labels_dimension,\n                                       expected_labels_dimension))\n    # Eager mode.\n    if tf.executing_eagerly():\n      labels_rank = labels._rank()  # pylint: disable=protected-access\n      logits_rank = logits._rank()  # pylint: disable=protected-access\n      if (labels_rank is not None and logits_rank is not None and\n          labels_rank == logits_rank - 1):\n        labels = tf.compat.v1.expand_dims(labels, -1)\n        labels_rank += 1\n      labels_shape = labels._shape_tuple()  # pylint: disable=protected-access\n     
 if labels_rank < 2:\n        raise ValueError('labels must have rank at least 2.  Received rank {}, '\n                         'shape {}'.format(labels_rank, labels_shape))\n      if labels_shape[-1] != expected_labels_dimension:\n        raise ValueError(\n            _MISMATCHED_LABEL_DIM_ERR_MSG.format(expected_labels_dimension,\n                                                 labels_shape[-1]))\n      logits_shape = logits._shape_tuple()  # pylint: disable=protected-access\n      expected_labels_shape = logits_shape[:-1] + (expected_labels_dimension,)\n      if expected_labels_shape != labels_shape:\n        raise ValueError(\n            '{}, expected_labels_shape: {}. labels_shape: {}.'.format(\n                _LABEL_SHAPE_ERR_MSG.format(expected_labels_dimension),\n                expected_labels_shape, labels_shape))\n      return labels\n\n    # Graph mode.\n    if (labels.shape.ndims is not None and logits.shape.ndims is not None and\n        labels.shape.ndims == logits.shape.ndims - 1):\n      labels = tf.compat.v1.expand_dims(labels, -1)\n    assert_rank = tf.compat.v1.debugging.assert_rank_at_least(\n        labels,\n        2,\n        message=_LABEL_SHAPE_ERR_MSG.format(expected_labels_dimension))\n    with tf.control_dependencies([assert_rank]):\n      static_shape = labels.shape\n      if static_shape.ndims is not None:\n        final_dim = static_shape[-1]\n        if (final_dim is not None) and (final_dim != expected_labels_dimension):\n          raise ValueError(\n              _MISMATCHED_LABEL_DIM_ERR_MSG.format(expected_labels_dimension,\n                                                   final_dim))\n      logits_shape = tf.compat.v1.shape(logits)\n      expected_labels_shape = tf.concat(\n          [logits_shape[:-1], [expected_labels_dimension]], axis=0)\n      labels_shape = tf.compat.v1.shape(labels)\n      assert_dimension = tf.compat.v1.debugging.assert_equal(\n          expected_labels_shape,\n          labels_shape,\n          
message=_LABEL_SHAPE_ERR_MSG.format(expected_labels_dimension),\n          data=[\n              'expected_labels_shape: ', expected_labels_shape,\n              'labels_shape: ', labels_shape\n          ])\n      with tf.control_dependencies([assert_dimension]):\n        return tf.identity(labels, name=scope)\n\n\ndef get_weights_and_check_match_logits(features,\n                                       weight_column,\n                                       logits,\n                                       allow_per_logit_weights=False):\n  \"\"\"Fetches weights from features and checks that the shape matches logits.\n\n  Consider logits of shape [D0, D1, ... DN, logits_dimension]. Weights shape\n  can be either:\n  * [D0, D1, ... DN, logits_dimension] if `allow_per_logit_weights=True`.\n  * [D0, D1, ... DN, 1]\n  * [D0, D1, ... DN]: In this case, weights is reshaped into\n    [D0, D1, ... DN, 1] to work with weight broadcasting rules.\n\n  Args:\n    features: The features dict that contains weights.\n    weight_column: The weight column. If not given, this method returns 1.\n    logits: logits Tensor.\n    allow_per_logit_weights: Boolean. Whether we allow weights along the logits\n      dimension, namely shape `[D0, D1, ... DN, logits_dimension]`.\n\n  Returns:\n    Validated and reshaped weights Tensor.\n\n  Raises:\n    ValueError: If the weights `Tensor` cannot be cast into float.\n  \"\"\"\n  if allow_per_logit_weights:\n    err_msg = ('weights shape must be [D0, D1, ... DN], [D0, D1, ... DN, 1] or '\n               '[D0, D1, ... DN, logits_dimension]')\n  else:\n    err_msg = ('weights shape must be [D0, D1, ... DN] or [D0, D1, ... 
DN, 1]')\n  with ops.name_scope(\n      'weights', values=tuple(six.itervalues(features)) + (logits,)) as scope:\n    # Fetch the weights.\n    if weight_column is None:\n      return 1.\n    # TODO(b/117839674): update feature_column\n    if isinstance(weight_column, six.string_types):\n      weight_column = tf.feature_column.numeric_column(\n          key=weight_column, shape=(1,))\n    if not isinstance(weight_column,\n                      (feature_column_lib.NumericColumn, _NumericColumn)):\n      raise TypeError('Weight column must be either a string or NumericColumn.'\n                      ' Given type: {}.'.format(type(weight_column)))\n    weights = weight_column._get_dense_tensor(  # pylint: disable=protected-access\n        _LazyBuilder(features))\n    if not (weights.dtype.is_floating or weights.dtype.is_integer):\n      raise ValueError('Weight column should be castable to float. '\n                       'Given dtype: {}'.format(weights.dtype))\n    weights = tf.cast(weights, name='weights', dtype=tf.dtypes.float32)\n    # Validate the weights shape.\n    # Eager mode.\n    if tf.executing_eagerly():\n      weights_shape = weights._shape_tuple()  # pylint: disable=protected-access\n      logits_shape = logits._shape_tuple()  # pylint: disable=protected-access\n      weights_rank = weights._rank()  # pylint: disable=protected-access\n      logits_rank = logits._rank()  # pylint: disable=protected-access\n      if (weights_rank is not None and logits_rank is not None and\n          weights_rank == logits_rank - 1):\n        if logits_shape[:-1] != weights_shape:\n          raise ValueError('{}, logits_shape: {}. 
weights_shape: {}.'.format(\n              err_msg, logits_shape, weights_shape))\n        return tf.compat.v1.expand_dims(weights, -1, name=scope)\n      supported_weights_shape = logits_shape[:-1] + (1,)\n      if allow_per_logit_weights:\n        if (logits_shape != weights_shape and\n            supported_weights_shape != weights_shape):\n          raise ValueError('{}, logits_shape: {}. weights_shape: {}.'.format(\n              err_msg, logits_shape, weights_shape))\n      else:\n        if supported_weights_shape != weights_shape:\n          raise ValueError('{}, logits_shape: {}. weights_shape: {}.'.format(\n              err_msg, logits_shape, weights_shape))\n      return weights\n\n    # Graph mode.\n    weights_shape = tf.compat.v1.shape(weights, name='weights_shape')\n    logits_shape = tf.compat.v1.shape(logits, name='logits_shape')\n    if (weights.shape.ndims is not None and logits.shape.ndims is not None and\n        weights.shape.ndims == logits.shape.ndims - 1):\n      assert_dimension = tf.compat.v1.debugging.assert_equal(\n          logits_shape[:-1],\n          weights_shape,\n          message=err_msg,\n          data=[\n              'logits_shape: ', logits_shape, 'weights_shape: ', weights_shape\n          ])\n      with tf.control_dependencies([assert_dimension]):\n        return tf.compat.v1.expand_dims(weights, -1, name=scope)\n    supported_weights_shape = tf.concat([logits_shape[:-1], [1]], axis=0)\n    if allow_per_logit_weights:\n      condition = tf.math.reduce_any([\n          tf.reduce_all(tf.math.equal(logits_shape, weights_shape)),\n          tf.reduce_all(tf.math.equal(supported_weights_shape, weights_shape))\n      ])\n      assert_dimension = tf.debugging.Assert(\n          condition=condition,\n          data=[\n              err_msg, 'logits_shape: ', logits_shape, 'weights_shape: ',\n              weights_shape\n          ])\n    else:\n      assert_dimension = tf.compat.v1.debugging.assert_equal(\n          
supported_weights_shape,\n          weights_shape,\n          message=err_msg,\n          data=[\n              'logits_shape: ', logits_shape, 'weights_shape: ', weights_shape\n          ])\n    with tf.control_dependencies([assert_dimension]):\n      return tf.identity(weights, name=scope)\n\n\ndef check_logits_final_dim(logits, expected_logits_dimension):\n  \"\"\"Checks that logits shape is [D0, D1, ... DN, logits_dimension].\"\"\"\n  with ops.name_scope('logits', values=(logits,)) as scope:\n    logits = tf.cast(logits, tf.dtypes.float32)\n    # Eager mode\n    if tf.executing_eagerly():\n      logits_shape = logits._shape_tuple()  # pylint: disable=protected-access\n      logits_rank = logits._rank()  # pylint: disable=protected-access\n      if logits_rank < 2:\n        raise ValueError('logits must have rank at least 2.  Received rank {}, '\n                         'shape {}'.format(logits_rank, logits_shape))\n      if (isinstance(expected_logits_dimension, int) and\n          logits_shape[-1] != expected_logits_dimension):\n        raise ValueError(\n            'logits shape must be [D0, D1, ... DN, logits_dimension], '\n            'got {}.'.format(logits_shape))\n      return logits\n    # Graph mode\n    logits_shape = tf.compat.v1.shape(logits)\n    assert_rank = tf.compat.v1.debugging.assert_rank_at_least(\n        logits,\n        2,\n        data=[logits_shape],\n        message='logits shape must be [D0, D1, ... DN, logits_dimension]')\n    with tf.control_dependencies([assert_rank]):\n      static_shape = logits.shape\n      if static_shape.ndims is not None and static_shape[-1] is not None:\n        if (isinstance(expected_logits_dimension, int) and\n            static_shape[-1] != expected_logits_dimension):\n          raise ValueError(\n              'logits shape must be [D0, D1, ... 
DN, logits_dimension], '\n              'got {}.'.format(static_shape))\n        return logits\n      assert_dimension = tf.compat.v1.debugging.assert_equal(\n          expected_logits_dimension,\n          logits_shape[-1],\n          data=[logits_shape],\n          message='logits shape must be [D0, D1, ... DN, logits_dimension]')\n      with tf.control_dependencies([assert_dimension]):\n        return tf.identity(logits, name=scope)\n\n\ndef validate_loss_fn_args(loss_fn):\n  \"\"\"Validates loss_fn arguments.\n\n  Required arguments: labels, logits.\n  Optional arguments: features, loss_reduction.\n\n  Args:\n    loss_fn: The loss function.\n\n  Raises:\n    ValueError: If the signature is unexpected.\n  \"\"\"\n  loss_fn_args = function_utils.fn_args(loss_fn)\n  for required_arg in ['labels', 'logits']:\n    if required_arg not in loss_fn_args:\n      raise ValueError('loss_fn must contain argument: {}. '\n                       'Given arguments: {}'.format(required_arg, loss_fn_args))\n  invalid_args = list(\n      set(loss_fn_args) -\n      set(['labels', 'logits', 'features', 'loss_reduction']))\n  if invalid_args:\n    raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args))\n\n\ndef validate_loss_reduction(loss_reduction):\n  if (loss_reduction not in tf.losses.Reduction.all() or\n      loss_reduction == tf.losses.Reduction.NONE):\n    raise ValueError(\n        'Invalid loss_reduction: {}. 
See `tf.losses.Reduction` for valid '\n        'options.'.format(loss_reduction))\n\n\ndef validate_update_ops(update_ops=None):\n  if update_ops is not None and not isinstance(update_ops, (list, tuple)):\n    raise ValueError(\n        _VALIDATION_ERROR_MSG.format('update_ops', type(update_ops)))\n\n\ndef validate_v2_optimizer(optimizer):\n  if not isinstance(\n      optimizer,\n      (tf_keras.optimizers.Optimizer, tf_keras.optimizers.legacy.Optimizer)):\n    raise ValueError(\n        'The given optimizer is not a tf_keras.optimizers.Optimizer '\n        f'instance. Received optimizer of type {type(optimizer)}')\n\n\ndef validate_trainable_variables(trainable_variables=None):\n  if trainable_variables is None:\n    raise ValueError('trainable_variables cannot be None. Given {}'.format(\n        trainable_variables))\n  if not isinstance(trainable_variables, (list, tuple)):\n    raise ValueError(\n        _VALIDATION_ERROR_MSG.format('trainable_variables',\n                                     type(trainable_variables)))\n\n\ndef validate_n_classes(n_classes):\n  \"\"\"Validates n_classes argument.\n\n  Required arguments: n_classes.\n\n  Args:\n    n_classes: The number of classes.\n\n  Raises:\n    ValueError: If n_classes is <= 2 and n_classes is a Python integer.\n  Returns:\n    n_classes in its original type.\n  \"\"\"\n  if isinstance(n_classes, int) and (n_classes <= 2):\n    raise ValueError('n_classes must be > 2: %s.' 
% n_classes)\n\n  n_classes_as_tensor = ops.convert_to_tensor(n_classes)\n  assert_n_classes = tf.compat.v1.debugging.assert_greater(\n      n_classes_as_tensor, 2, message='n_classes must be greater than 2')\n  with tf.control_dependencies([assert_n_classes]):\n    tf.no_op()\n  # Return n_classes in its original type, so that any code\n  # using the accessor logits_dimension() has the original type.\n  return n_classes\n\n\ndef call_loss_fn(loss_fn, labels, logits, features, expected_loss_dim=1):\n  \"\"\"Calls loss_fn and checks the returned shape.\n\n  For shape checking, eager uses the static dimension to improve performance.\n\n  Args:\n    loss_fn: The loss function.\n    labels: Processed labels Tensor.\n    logits: Logits Tensor of shape [D0, D1, ... DN, logits_dimension].\n    features: Features dict.\n    expected_loss_dim: The expected last dimension of loss Tensor.\n\n  Returns:\n    Loss Tensor with shape [D0, D1, ... DN, expected_loss_dim].\n\n  Raises:\n    ValueError: If the loss tensor shape is unexpected.\n  \"\"\"\n  loss_fn_args = function_utils.fn_args(loss_fn)\n  kwargs = {}\n  if 'features' in loss_fn_args:\n    kwargs['features'] = features\n  with ops.name_scope(\n      'call_loss_fn', values=[labels, logits] + list(six.itervalues(features))):\n    unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs)\n    # Eager mode.\n    if tf.executing_eagerly():\n      loss_shape = unweighted_loss._shape_tuple()  # pylint: disable=protected-access\n      logits_shape = logits._shape_tuple()  # pylint: disable=protected-access\n      expected_loss_shape = logits_shape[:-1] + (expected_loss_dim,)\n      if loss_shape != expected_loss_shape:\n        raise ValueError(\n            'loss_fn must return Tensor of shape '\n            '[D0, D1, ... DN, {}]. 
'.format(expected_loss_dim),\n            'logits_shape: ', logits_shape, 'loss_shape: ', loss_shape)\n      return unweighted_loss\n    # Graph mode.\n    logits_shape = tf.compat.v1.shape(logits, name='logits_shape')\n    expected_loss_shape = tf.concat([logits_shape[:-1], [expected_loss_dim]],\n                                    axis=0,\n                                    name='expected_loss_shape')\n    loss_shape = tf.compat.v1.shape(unweighted_loss, name='loss_shape')\n    check_loss_shape_op = tf.debugging.Assert(\n        tf.reduce_all(tf.math.equal(loss_shape, expected_loss_shape)),\n        data=[\n            'loss_fn must return Tensor of shape '\n            '[D0, D1, ... DN, {}]. '.format(expected_loss_dim),\n            'logits_shape: ', logits_shape, 'loss_shape: ', loss_shape\n        ],\n        name='check_loss_shape')\n    with tf.control_dependencies([check_loss_shape_op]):\n      return tf.identity(unweighted_loss)\n\n\ndef check_prediction_keys(pred_keys, valid_keys):\n  for key in pred_keys:\n    if key not in valid_keys:\n      raise ValueError('Prediction key must be in PredictionKeys, given: {}.'\n                       'Valid prediction keys include {}.'.format(\n                           key, valid_keys))\n\n\ndef all_class_ids(logits, n_classes):\n  batch_size = tf.compat.v1.shape(logits)[0]\n  class_id_list = tf.range(n_classes)\n  return tf.tile(\n      input=tf.compat.v1.expand_dims(input=class_id_list, axis=0),\n      multiples=[batch_size, 1])\n\n\ndef all_classes(logits, n_classes, label_vocabulary=None):\n  batch_size = tf.compat.v1.shape(logits)[0]\n  if label_vocabulary:\n    classes_list = tf.convert_to_tensor([label_vocabulary])\n  else:\n    classes_list = tf.expand_dims(tf.range(n_classes), axis=0)\n    classes_list = tf.strings.as_string(classes_list)\n  return tf.tile(input=classes_list, multiples=[batch_size, 1])\n\n\ndef classification_output(scores, n_classes, label_vocabulary=None):\n  return 
export_output.ClassificationOutput(\n      scores=scores,\n      # `ClassificationOutput` requires string classes.\n      classes=all_classes(scores, n_classes, label_vocabulary))\n\n\ndef check_label_range(labels, n_classes, message=None):\n  \"\"\"Check if labels are in the range of [0, n_classes).\"\"\"\n  with ops.name_scope('check_label_range', values=(labels,)):\n    # Eager mode\n    if tf.executing_eagerly():\n      assert_less = tf.reduce_all(tf.math.less_equal(labels, n_classes - 1))\n      if not assert_less:\n        raise ValueError(message or\n                         'Labels must be <= {} - 1'.format(n_classes))\n      assert_greater = tf.reduce_all(tf.math.greater_equal(labels, 0))\n      if not assert_greater:\n        raise ValueError(message or 'Labels must be >= 0')\n      return labels\n    # Graph mode\n    assert_less = tf.compat.v1.debugging.assert_less_equal(\n        labels,\n        ops.convert_to_tensor(n_classes - 1, dtype=labels.dtype),\n        message=message or 'Labels must be <= n_classes - 1')\n    assert_greater = tf.compat.v1.debugging.assert_non_negative(\n        labels, message=message or 'Labels must be >= 0')\n    with tf.control_dependencies((assert_less, assert_greater)):\n      return tf.identity(labels)\n\n\ndef update_metric_with_broadcast_weights(eval_metric, values, weights):\n  values = tf.cast(values, dtype=tf.dtypes.float32)\n  if weights is not None:\n    weights = tf.compat.v2.__internal__.ops.broadcast_weights(weights, values)\n  eval_metric.update_state(values=values, sample_weight=weights)\n\n\ndef create_eval_metrics_tuple(fn, kwargs):\n  \"\"\"Creates TPU eval metrics tuple.\n\n  Helper function to make eval_metric tuple (eval_metric_fn, fn_kwargs) used\n  by `TPUEstimator`. TPUEstimator requires that `eval_metric_fn` take\n  exclusively Tensor arguments. 
This helper can help create such a function from\n  a more generic function that can take both Tensor and non-Tensor arguments.\n\n  Args:\n    fn: A eval_metric_fn that takes both Tensor and non-Tensor arguments. This\n      function must return a dict of form\n        {'metric name': (metric_tensor, eval_op)}\n    kwargs: Dict of arguments for `fn`.\n\n  Returns:\n    `eval_metric` tuple that can be passed to a `model_fn._TPUEstimatorSpec`.\n  \"\"\"\n  tensor_kwargs = {}\n  nontensor_kwargs = {}\n  for k, v in six.iteritems(kwargs):\n    if tf.is_tensor(v):\n      tensor_kwargs[k] = v\n    else:\n      nontensor_kwargs[k] = v\n\n  def _fn(**tensors):\n    return fn(**dict(nontensor_kwargs, **tensors))\n\n  return (_fn, tensor_kwargs)\n\n\ndef create_estimator_spec_train_op(\n    head_name,\n    optimizer=None,\n    trainable_variables=None,\n    train_op_fn=None,\n    update_ops=None,\n    regularized_training_loss=None,\n    loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE):\n  \"\"\"Create train_op for estimator_spec.\n\n  Args:\n    head_name: The name of the head.\n    optimizer: An `tf_keras.optimizers.Optimizer` instance to optimize the loss\n      in TRAIN mode. Namely, sets `train_op = optimizer.get_updates(loss,\n      trainable_variables)`, which updates variables to minimize `loss`.\n    trainable_variables: A list or tuple of `Variable` objects to update to\n      minimize `loss`. In Tensorflow 1.x, by default these are the list of\n      variables collected in the graph under the key\n      `GraphKeys.TRAINABLE_VARIABLES`. As Tensorflow 2.x doesn't have\n      collections and GraphKeys, trainable_variables need to be passed\n      explicitly here.\n    train_op_fn: Function that takes a scalar loss `Tensor` and returns\n      `train_op`. Used if `optimizer` is `None`.\n    update_ops: A list or tuple of update ops to be run at training time. 
For\n      example, layers such as BatchNormalization create mean and variance update\n      ops that need to be run at training time. In Tensorflow 1.x, these are\n      thrown into an UPDATE_OPS collection. As Tensorflow 2.x doesn't have\n      collections, update_ops need to be passed explicitly here.\n    regularized_training_loss: A scalar for total training loss that includes\n      all regularization losses. If you're not using optimizer to generate train\n      op, make sure to scale the loss correctly before passing it in. The loss\n      typically needs to be scaled down by the number of workers.\n    loss_reduction: One of `tf_keras.losses.Reduction` except `NONE`. Describes\n      how to reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.\n\n  Returns:\n    A train op for EstimatorSpec.\n  \"\"\"\n  del head_name\n  validate_update_ops(update_ops)\n  with ops.name_scope(''):  # Reset all previous name_scope.\n    # Add training as the name_scope to be compatible with Keras.\n    with ops.name_scope('training'):\n      if optimizer is not None:\n        if train_op_fn is not None:\n          raise ValueError('train_op_fn and optimizer cannot both be set.')\n        validate_v2_optimizer(optimizer)\n        validate_trainable_variables(trainable_variables)\n        # Scale loss by number of replicas.\n        if loss_reduction == tf.losses.Reduction.SUM_OVER_BATCH_SIZE:\n          num_replicas = tf.distribute.get_strategy().num_replicas_in_sync\n          if num_replicas > 1:\n            regularized_training_loss *= (1. 
/ num_replicas)\n        train_op = optimizer.get_updates(regularized_training_loss,\n                                         trainable_variables)[0]\n      elif train_op_fn is not None:\n        train_op = train_op_fn(regularized_training_loss)\n      else:\n        raise ValueError('train_op_fn and optimizer cannot both be None.')\n      if update_ops is not None:\n        train_op = tf.group(train_op, *update_ops)\n      return train_op\n\n\ndef create_estimator_spec_summary(regularized_training_loss,\n                                  regularization_losses=None,\n                                  summary_key_fn=None):\n  \"\"\"Create summary for estimator_spec.\"\"\"\n  with ops.name_scope(''):\n    keys = metric_keys.MetricKeys\n    loss_key = summary_key_fn(keys.LOSS) if summary_key_fn else keys.LOSS\n    tf.compat.v1.summary.scalar(loss_key, regularized_training_loss)\n    if regularization_losses is not None:\n      regularization_loss = tf.math.add_n(regularization_losses)\n      regularization_loss_key = (\n          summary_key_fn(keys.LOSS_REGULARIZATION)\n          if summary_key_fn else keys.LOSS_REGULARIZATION)\n      tf.compat.v1.summary.scalar(regularization_loss_key, regularization_loss)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/base_head_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for base_head.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator import model_fn\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.head import base_head\nfrom tensorflow_estimator.python.estimator.head import binary_class_head as head_lib\nfrom tensorflow_estimator.python.estimator.head import head_utils as test_lib\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n_DEFAULT_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n\n\ndef _assert_simple_summaries(test_case,\n                             expected_summaries,\n                             summary_str,\n                             tol=1e-6):\n  \"\"\"Assert summary the specified simple values.\n\n  Args:\n    test_case: test case.\n    expected_summaries: Dict of expected tags and simple values.\n    summary_str: Serialized `summary_pb2.Summary`.\n    tol: Tolerance for relative and absolute.\n  \"\"\"\n  summary = tf.compat.v1.summary.Summary()\n  
summary.ParseFromString(summary_str)\n  test_case.assertAllClose(\n      expected_summaries, {v.tag: v.simple_value for v in summary.value},\n      rtol=tol,\n      atol=tol)\n\n\ndef _assert_no_hooks(test_case, spec):\n  test_case.assertAllEqual([], spec.training_chief_hooks)\n  test_case.assertAllEqual([], spec.training_hooks)\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass CreateEstimatorSpecTest(tf.test.TestCase):\n\n  class _HeadWithTPUSupport(base_head.Head):\n    \"\"\"Head that overrides _create_tpu_estimator_spec.\"\"\"\n\n    def name(self):\n      return 'HeadWithTPUSupport'\n\n    def logits_dimension(self):\n      return None\n\n    def loss_reduction(self):\n      return None\n\n    def loss(self, features, mode, logits, labels):\n      return None\n\n    def predictions(self, logits):\n      return None\n\n    def metrics(self, regularization_losses=None):\n      return None\n\n    def update_metrics(self,\n                       eval_metrics,\n                       features,\n                       logits,\n                       labels,\n                       mode=None,\n                       regularization_losses=None):\n      return None\n\n    def _create_tpu_estimator_spec(\n        self,\n        features,\n        mode,\n        logits,\n        labels=None,\n        optimizer=None,\n        trainable_variables=None,\n        train_op_fn=None,\n        update_ops=None,\n        regularization_losses=None,\n    ):\n      return model_fn._TPUEstimatorSpec(\n          mode=ModeKeys.EVAL, loss=tf.constant(0.0, dtype=tf.dtypes.float32))\n\n  class _HeadWithOutTPUSupport(base_head.Head):\n    \"\"\"Head that overrides create_estimator_spec.\"\"\"\n\n    def name(self):\n      return 'HeadWithOutTPUSupport'\n\n    def logits_dimension(self):\n      return None\n\n    def loss_reduction(self):\n      return None\n\n    def loss(self, features, mode, logits, labels):\n      return None\n\n    def predictions(self, logits):\n      return 
None\n\n    def metrics(self, regularization_losses=None):\n      return None\n\n    def update_metrics(self,\n                       eval_metrics,\n                       features,\n                       logits,\n                       labels,\n                       mode=None,\n                       regularization_losses=None):\n      return None\n\n    def create_estimator_spec(\n        self,\n        features,\n        mode,\n        logits,\n        labels=None,\n        optimizer=None,\n        trainable_variables=None,\n        train_op_fn=None,\n        update_ops=None,\n        regularization_losses=None,\n    ):\n      return model_fn.EstimatorSpec(\n          mode=ModeKeys.EVAL, loss=tf.constant(0.0, dtype=tf.dtypes.float32))\n\n  class _InvalidHead(base_head.Head):\n    \"\"\"Head that overrides neither estimator_spec functions.\"\"\"\n\n    def name(self):\n      return 'InvalidHead'\n\n    def logits_dimension(self):\n      return None\n\n    def loss_reduction(self):\n      return None\n\n    def loss(self, features, mode, logits, labels):\n      return None\n\n    def predictions(self, logits):\n      return None\n\n    def metrics(self, regularization_losses=None):\n      return None\n\n    def update_metrics(self,\n                       eval_metrics,\n                       features,\n                       logits,\n                       labels,\n                       mode=None,\n                       regularization_losses=None):\n      return None\n\n  def test_head_override_tpu_estimator_spec(self):\n    \"\"\"Test for `_Head` that overrides _create_tpu_estimator_spec.\"\"\"\n    head = self._HeadWithTPUSupport()\n\n    tpu_spec = head._create_tpu_estimator_spec(\n        features=None, mode=None, logits=None)\n    self.assertTrue(isinstance(tpu_spec, model_fn._TPUEstimatorSpec))\n    est_spec = head.create_estimator_spec(features=None, mode=None, logits=None)\n    self.assertTrue(isinstance(est_spec, model_fn.EstimatorSpec))\n\n  def 
test_head_override_estimator_spec(self):\n    \"\"\"Test for `Head` that overrides create_estimator_spec.\"\"\"\n    head = self._HeadWithOutTPUSupport()\n\n    with self.assertRaisesRegexp(\n        NotImplementedError,\n        'TPUEstimatorSpec not available for this model head.'):\n      _ = head._create_tpu_estimator_spec(features=None, mode=None, logits=None)\n    est_spec = head.create_estimator_spec(features=None, mode=None, logits=None)\n    self.assertTrue(isinstance(est_spec, model_fn.EstimatorSpec))\n\n  def test_invalid_head_class(self):\n    head = self._InvalidHead()\n\n    with self.assertRaisesRegexp(\n        NotImplementedError,\n        'TPUEstimatorSpec not available for this model head.'):\n      _ = head._create_tpu_estimator_spec(features=None, mode=None, logits=None)\n    with self.assertRaisesRegexp(\n        NotImplementedError,\n        r'Subclasses of Head must implement `create_estimator_spec\\(\\)` or '\n        r'_create_tpu_estimator_spec\\(\\).'):\n      _ = head.create_estimator_spec(features=None, mode=None, logits=None)\n\n  @test_util.deprecated_graph_mode_only\n  def test_tensor_shape_checking_in_graph_mode(self):\n    \"\"\"Test for shape checking of tensor with partially defined shape.\"\"\"\n    labels_placeholder = tf.compat.v1.placeholder(\n        dtype=tf.dtypes.float32, shape=(None, 1))\n    logits_placeholder = tf.compat.v1.placeholder(\n        dtype=tf.dtypes.float32, shape=(None, 1))\n    labels_input = np.array([[-10.], [10.]], dtype=np.float32)\n    logits_input = np.array([[1.], [0.]], dtype=np.float32)\n\n    loss = np.array([[1.], [2.]], dtype=np.float32)\n\n    def _loss_fn(labels, logits):\n      check_labels = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(labels, labels_input)), data=[labels])\n      check_logits = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(logits, logits_input)), data=[logits])\n      with tf.control_dependencies([check_labels, check_logits]):\n        return 
tf.constant(loss)\n\n    unweighted_loss = base_head.call_loss_fn(\n        loss_fn=_loss_fn,\n        labels=labels_placeholder,\n        logits=logits_placeholder,\n        features={'x': np.array(((42,),), dtype=np.int32)})\n    with self.cached_session():\n      self.assertAllClose(\n          unweighted_loss.eval({\n              labels_placeholder: labels_input,\n              logits_placeholder: logits_input\n          }), loss)\n\n  @test_util.deprecated_graph_mode_only\n  def test_optimizer_v2_variable_name(self):\n    head = head_lib.BinaryClassHead()\n\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.float64)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n\n    class _Optimizer(tf_keras.optimizers.legacy.Optimizer):\n\n      def init(self, name, **kwargs):\n        super(_Optimizer, self).__init__(name, **kwargs)\n\n      def get_updates(self, loss, params):\n        del params\n        variable = tf.Variable(\n            name='my_variable', dtype=tf.dtypes.float32, initial_value=0.)\n        self._weights.append(variable)\n        return [variable]\n\n      def get_config(self):\n        config = super(_Optimizer, self).get_config()\n        return config\n\n    # Create estimator spec.\n    optimizer = _Optimizer('my_optimizer')\n    old_opt_variable_name_prefix = 'training/' + optimizer.__class__.__name__\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        optimizer=optimizer,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      optimizer_variables = optimizer.variables()\n      var_values = sess.run(optimizer_variables)\n      self.assertEqual(0., var_values[0])\n      for var in optimizer_variables:\n        
self.assertNotIn(old_opt_variable_name_prefix, var.name)\n\n  @test_util.deprecated_graph_mode_only\n  def test_head_with_invalid_optimizer(self):\n    head = head_lib.BinaryClassHead()\n\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.float64)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n\n    with self.assertRaisesRegex(\n        ValueError,\n        'The given optimizer is not a tf_keras.optimizers.Optimizer'):\n      # Create estimator spec.\n      head.create_estimator_spec(\n          features=features,\n          mode=ModeKeys.TRAIN,\n          logits=logits,\n          labels=labels,\n          optimizer=tf.compat.v1.train.AdamOptimizer())\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/binary_class_head.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Binary class head.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.ops import lookup_ops\nfrom tensorflow_estimator.python.estimator import model_fn\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.export import export_output\nfrom tensorflow_estimator.python.estimator.head import base_head\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v2\n\n\n@estimator_export('estimator.BinaryClassHead')\nclass BinaryClassHead(base_head.Head):\n  \"\"\"Creates a `Head` for single label binary classification.\n\n  Uses `sigmoid_cross_entropy_with_logits` loss.\n\n  The head expects `logits` with shape `[D0, D1, ... 
DN, 1]`.\n  In many applications, the shape is `[batch_size, 1]`.\n\n  `labels` must be a dense `Tensor` with shape matching `logits`, namely\n  `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string\n  `Tensor` with values from the vocabulary. If `label_vocabulary` is not given,\n  `labels` must be float `Tensor` with values in the interval `[0, 1]`.\n\n  If `weight_column` is specified, weights must be of shape\n  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.\n\n  The loss is the weighted sum over the input dimensions. Namely, if the input\n  labels have shape `[batch_size, 1]`, the loss is the weighted sum over\n  `batch_size`.\n\n  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or\n  `(labels, logits, features, loss_reduction)` as arguments and returns loss\n  with shape `[D0, D1, ... DN, 1]`. `loss_fn` must support float `labels` with\n  shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to\n  the input labels before passing them to `loss_fn`.\n\n  Usage:\n\n  >>> head = tf.estimator.BinaryClassHead()\n  >>> logits = np.array(((45,), (-41,),), dtype=np.float32)\n  >>> labels = np.array(((1,), (1,),), dtype=np.int32)\n  >>> features = {'x': np.array(((42,),), dtype=np.float32)}\n  >>> # expected_loss = sum(cross_entropy(labels, logits)) / batch_size\n  >>> #               = sum(0, 41) / 2 = 41 / 2 = 20.50\n  >>> loss = head.loss(labels, logits, features=features)\n  >>> print('{:.2f}'.format(loss.numpy()))\n  20.50\n  >>> eval_metrics = head.metrics()\n  >>> updated_metrics = head.update_metrics(\n  ...   eval_metrics, features, logits, labels)\n  >>> for k in sorted(updated_metrics):\n  ...  
print('{} : {:.2f}'.format(k, updated_metrics[k].result().numpy()))\n    accuracy : 0.50\n    accuracy_baseline : 1.00\n    auc : 0.00\n    auc_precision_recall : 1.00\n    average_loss : 20.50\n    label/mean : 1.00\n    precision : 1.00\n    prediction/mean : 0.50\n    recall : 0.50\n  >>> preds = head.predictions(logits)\n  >>> print(preds['logits'])\n  tf.Tensor(\n    [[ 45.]\n     [-41.]], shape=(2, 1), dtype=float32)\n\n  Usage with a canned estimator:\n\n  ```python\n  my_head = tf.estimator.BinaryClassHead()\n  my_estimator = tf.estimator.DNNEstimator(\n      head=my_head,\n      hidden_units=...,\n      feature_columns=...)\n  ```\n\n  It can also be used with a custom `model_fn`. Example:\n\n  ```python\n  def _my_model_fn(features, labels, mode):\n    my_head = tf.estimator.BinaryClassHead()\n    logits = tf_keras.Model(...)(features)\n\n    return my_head.create_estimator_spec(\n        features=features,\n        mode=mode,\n        labels=labels,\n        optimizer=tf_keras.optimizers.Adagrad(lr=0.1),\n        logits=logits)\n\n  my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)\n  ```\n\n  Args:\n    weight_column: A string or a `NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It is used to down weight or boost examples during training. It\n      will be multiplied by the loss of the example.\n    thresholds: Iterable of floats in the range `(0, 1)`. For binary\n      classification metrics such as precision and recall, an eval metric is\n      generated for each threshold value. This threshold is applied to the\n      logistic values to determine the binary classification (i.e., above the\n      threshold is `true`, below is `false`.\n    label_vocabulary: A list or tuple of strings representing possible label\n      values. If it is not given, that means labels are already encoded within\n      [0, 1]. 
If given, labels must be string type and have any value in\n      `label_vocabulary`. Note that errors will be raised if `label_vocabulary`\n      is not provided but labels are strings.\n    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Decides how to\n      reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely\n      weighted sum of losses divided by `batch size * label_dimension`.\n    loss_fn: Optional loss function.\n    name: Name of the head. If provided, summary and metrics keys will be\n      suffixed by `\"/\" + name`. Also used as `name_scope` when creating ops.\n  \"\"\"\n\n  def __init__(self,\n               weight_column=None,\n               thresholds=None,\n               label_vocabulary=None,\n               loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,\n               loss_fn=None,\n               name=None):\n    if label_vocabulary is not None and not isinstance(label_vocabulary,\n                                                       (list, tuple)):\n      raise ValueError(\n          'label_vocabulary should be a list or a tuple. 
Given type: {}'.format(\n              type(label_vocabulary)))\n    thresholds = tuple(thresholds) if thresholds else tuple()\n    for threshold in thresholds:\n      if (threshold <= 0.0) or (threshold >= 1.0):\n        raise ValueError('thresholds not in (0, 1): {}.'.format((thresholds,)))\n    base_head.validate_loss_reduction(loss_reduction)\n    if loss_fn:\n      base_head.validate_loss_fn_args(loss_fn)\n    self._weight_column = weight_column\n    self._thresholds = thresholds\n    self._label_vocabulary = label_vocabulary\n    self._loss_reduction = loss_reduction\n    self._loss_fn = loss_fn\n    self._name = name\n    # Metric keys.\n    keys = metric_keys.MetricKeys\n    self._loss_mean_key = self._summary_key(keys.LOSS_MEAN)\n    self._accuracy_key = self._summary_key(keys.ACCURACY)\n    self._precision_key = self._summary_key(keys.PRECISION)\n    self._recall_key = self._summary_key(keys.RECALL)\n    self._prediction_mean_key = self._summary_key(keys.PREDICTION_MEAN)\n    self._label_mean_key = self._summary_key(keys.LABEL_MEAN)\n    self._accuracy_baseline_key = self._summary_key(keys.ACCURACY_BASELINE)\n    self._auc_key = self._summary_key(keys.AUC)\n    self._auc_pr_key = self._summary_key(keys.AUC_PR)\n    self._loss_regularization_key = self._summary_key(keys.LOSS_REGULARIZATION)\n    accuracy_keys = []\n    precision_keys = []\n    recall_keys = []\n    for threshold in self._thresholds:\n      accuracy_keys.append(\n          self._summary_key(keys.ACCURACY_AT_THRESHOLD % threshold))\n      precision_keys.append(\n          self._summary_key(keys.PRECISION_AT_THRESHOLD % threshold))\n      recall_keys.append(\n          self._summary_key(keys.RECALL_AT_THRESHOLD % threshold))\n    self._accuracy_keys = tuple(accuracy_keys)\n    self._precision_keys = tuple(precision_keys)\n    self._recall_keys = tuple(recall_keys)\n\n  @property\n  def name(self):\n    \"\"\"See `base_head.Head` for details.\"\"\"\n    return self._name\n\n  @property\n  def 
logits_dimension(self):\n    \"\"\"See `base_head.Head` for details.\"\"\"\n    return 1\n\n  @property\n  def loss_reduction(self):\n    \"\"\"See `base_head.Head` for details.\"\"\"\n    return self._loss_reduction\n\n  # Attributes for lookup tables in Eager execution. Note that for Graph\n  # execution, the lookup tables are created on demand to make sure the lookup\n  # table is in the same graph as its input tensors for `train` and `eval` of\n  # Estimator (as Estimator recreates graphs for `train`, `eval` and\n  # `predict`).\n  _cached_class_id_table = None\n  _cached_class_string_table = None\n\n  @property\n  def _class_id_table(self):\n    \"\"\"Creates a lookup table for class_id.\n\n    In eager execution, this lookup table will be lazily created on the first\n    call of `self._class_id_table`, and cached for later use; In graph\n    execution, it will be created on demand.\n\n    Returns:\n      A hash table for lookup.\n    \"\"\"\n    if self._cached_class_id_table is None or not tf.executing_eagerly():\n      self._cached_class_id_table = lookup_ops.index_table_from_tensor(\n          vocabulary_list=tuple(self._label_vocabulary), name='class_id_lookup')\n    return self._cached_class_id_table\n\n  @property\n  def _class_string_table(self):\n    \"\"\"Creates a lookup table for class_string.\n\n    In eager execution, this lookup table will be lazily created on the first\n    call of `self._class_string_table` and cached for later use; In graph\n    execution, it will be created on demand.\n\n    Returns:\n      A hash table for lookup.\n    \"\"\"\n    if (self._cached_class_string_table is None or not tf.executing_eagerly()):\n      self._cached_class_string_table = (\n          lookup_ops.index_to_string_table_from_tensor(\n              vocabulary_list=self._label_vocabulary,\n              name='class_string_lookup'))\n    return self._cached_class_string_table\n\n  def _processed_labels(self, logits, labels):\n    \"\"\"Converts labels to 
integer id space.\"\"\"\n    labels = base_head.check_dense_labels_match_logits_and_reshape(\n        labels=labels, logits=logits, expected_labels_dimension=1)\n    if self._label_vocabulary is not None:\n      labels = self._class_id_table.lookup(labels)\n    labels = tf.cast(labels, dtype=tf.dtypes.float32)\n    return base_head.check_label_range(labels, n_classes=2)\n\n  def _unweighted_loss_and_weights(self, logits, labels, features):\n    \"\"\"Computes unweighted loss and weights.\"\"\"\n    if self._loss_fn:\n      unweighted_loss = base_head.call_loss_fn(\n          loss_fn=self._loss_fn,\n          labels=labels,\n          logits=logits,\n          features=features,\n          expected_loss_dim=1)\n    else:\n      unweighted_loss = tf.compat.v1.nn.sigmoid_cross_entropy_with_logits(\n          labels=labels, logits=logits)\n    weights = base_head.get_weights_and_check_match_logits(\n        features=features, weight_column=self._weight_column, logits=logits)\n    return unweighted_loss, weights\n\n  def loss(self,\n           labels,\n           logits,\n           features=None,\n           mode=None,\n           regularization_losses=None):\n    \"\"\"Returns regularized training loss. 
See `base_head.Head` for details.\"\"\"\n    del mode  # Unused for this head.\n    with ops.name_scope(\n        'losses', values=(logits, labels, regularization_losses, features)):\n      logits = base_head.check_logits_final_dim(logits, self.logits_dimension)\n      labels = self._processed_labels(logits, labels)\n      unweighted_loss, weights = self._unweighted_loss_and_weights(\n          logits, labels, features)\n      training_loss = tf_keras_v2.__internal__.losses.compute_weighted_loss(\n          unweighted_loss,\n          sample_weight=weights,\n          reduction=self._loss_reduction)\n      regularization_loss = tf.math.add_n(\n          regularization_losses) if regularization_losses is not None else None\n      regularized_training_loss = (\n          training_loss + regularization_loss\n          if regularization_loss is not None else training_loss)\n    return regularized_training_loss\n\n  def predictions(self, logits, keys=None):\n    \"\"\"Return predictions based on keys.\n\n    See `base_head.Head` for details.\n\n    Args:\n      logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.\n        For many applications, the shape is `[batch_size, logits_dimension]`.\n      keys: a list or tuple of prediction keys. Each key can be either the class\n        variable of prediction_keys.PredictionKeys or its string value, such as:\n          prediction_keys.PredictionKeys.CLASSES or 'classes'. 
If not specified,\n          it will return the predictions for all valid keys.\n\n    Returns:\n      A dict of predictions.\n    \"\"\"\n    pred_keys = prediction_keys.PredictionKeys\n    valid_keys = [\n        pred_keys.LOGITS, pred_keys.LOGISTIC, pred_keys.PROBABILITIES,\n        pred_keys.CLASS_IDS, pred_keys.CLASSES, pred_keys.ALL_CLASS_IDS,\n        pred_keys.ALL_CLASSES\n    ]\n\n    if keys:\n      base_head.check_prediction_keys(keys, valid_keys)\n    else:\n      keys = valid_keys\n    logits = base_head.check_logits_final_dim(logits, self.logits_dimension)\n    predictions = {}\n    with ops.name_scope('predictions', values=(logits,)):\n      if pred_keys.LOGITS in keys:\n        predictions[pred_keys.LOGITS] = logits\n      if pred_keys.LOGISTIC in keys:\n        logistic = tf.math.sigmoid(logits, name=pred_keys.LOGISTIC)\n        predictions[pred_keys.LOGISTIC] = logistic\n      two_class_logits = tf.concat((tf.compat.v1.zeros_like(logits), logits),\n                                   axis=-1,\n                                   name='two_class_logits')\n      if pred_keys.PROBABILITIES in keys:\n        probabilities = tf.compat.v1.nn.softmax(\n            two_class_logits, name=pred_keys.PROBABILITIES)\n        predictions[pred_keys.PROBABILITIES] = probabilities\n      if pred_keys.CLASS_IDS in keys or pred_keys.CLASSES in keys:\n        class_ids = tf.compat.v1.math.argmax(\n            two_class_logits, axis=-1, name=pred_keys.CLASS_IDS)\n        class_ids = tf.compat.v1.expand_dims(class_ids, axis=-1)\n        if pred_keys.CLASS_IDS in keys:\n          predictions[pred_keys.CLASS_IDS] = class_ids\n        if pred_keys.CLASSES in keys:\n          if self._label_vocabulary is not None:\n            classes = self._class_string_table.lookup(class_ids)\n          else:\n            classes = tf.strings.as_string(class_ids, name='str_classes')\n          predictions[pred_keys.CLASSES] = classes\n      if pred_keys.ALL_CLASS_IDS in keys:\n        
predictions[pred_keys.ALL_CLASS_IDS] = base_head.all_class_ids(\n            logits, n_classes=2)\n      if pred_keys.ALL_CLASSES in keys:\n        predictions[pred_keys.ALL_CLASSES] = base_head.all_classes(\n            logits, n_classes=2, label_vocabulary=self._label_vocabulary)\n      return predictions\n\n  def metrics(self, regularization_losses=None):\n    \"\"\"Creates metrics. See `base_head.Head` for details.\"\"\"\n    keys = metric_keys.MetricKeys\n    with ops.name_scope('metrics', values=(regularization_losses,)):\n      # Mean metric.\n      eval_metrics = {}\n      eval_metrics[self._loss_mean_key] = tf_keras.metrics.Mean(\n          name=keys.LOSS_MEAN)\n      eval_metrics[self._accuracy_key] = tf_keras.metrics.Accuracy(\n          name=keys.ACCURACY)\n      eval_metrics[self._precision_key] = tf_keras.metrics.Precision(\n          name=keys.PRECISION)\n      eval_metrics[self._recall_key] = tf_keras.metrics.Recall(\n          name=keys.RECALL)\n      eval_metrics[self._prediction_mean_key] = tf_keras.metrics.Mean(\n          name=keys.PREDICTION_MEAN)\n      eval_metrics[self._label_mean_key] = tf_keras.metrics.Mean(\n          name=keys.LABEL_MEAN)\n      eval_metrics[self._accuracy_baseline_key] = tf_keras.metrics.Mean(\n          name=keys.ACCURACY_BASELINE)\n      # The default summation_method is \"interpolation\" in the AUC metric.\n      eval_metrics[self._auc_key] = tf_keras.metrics.AUC(name=keys.AUC)\n      eval_metrics[self._auc_pr_key] = tf_keras.metrics.AUC(\n          curve='PR', name=keys.AUC_PR)\n      if regularization_losses is not None:\n        eval_metrics[self._loss_regularization_key] = tf_keras.metrics.Mean(\n            name=keys.LOSS_REGULARIZATION)\n      for i, threshold in enumerate(self._thresholds):\n        eval_metrics[self._accuracy_keys[i]] = tf_keras.metrics.BinaryAccuracy(\n            name=self._accuracy_keys[i], threshold=threshold)\n        eval_metrics[self._precision_keys[i]] = tf_keras.metrics.Precision(\n 
           name=self._precision_keys[i], thresholds=threshold)\n        eval_metrics[self._recall_keys[i]] = tf_keras.metrics.Recall(\n            name=self._recall_keys[i], thresholds=threshold)\n    return eval_metrics\n\n  def _update_accuracy_baseline(self, eval_metrics):\n    \"\"\"Update accuracy baseline metric based on labels mean metric.\n\n    This is the best the model could do by always predicting one class.\n\n    For example, suppose the labels = [0, 1, 0, 1, 1]. So the\n    label_mean.total = 3, label_mean.count = 5, and\n    label_mean = label_mean.total / label_mean.count = 3 / 5 = 0.6\n    By always predicting one class, there are two cases:\n    (1) predicted_labels_0 = [0, 0, 0, 0, 0], accuracy_0 = 2 / 5 = 0.4\n    (2) predicted_labels_1 = [1, 1, 1, 1, 1], accuracy_1 = 3 / 5 = 0.6\n    So the accuracy_baseline = max(accuracy_0, accuracy_1) = 0.6,\n                             = max(label_mean, 1 - label_mean)\n\n    To update the total and count of accuracy_baseline,\n    accuracy_baseline = max(label_mean, 1 - label_mean)\n                      = max(label_mean.total / label_mean.count,\n                            1 - label_mean.total / label_mean.count)\n                      = max(label_mean.total / label_mean.count,\n                      (label_mean.count - label_mean.total) / label_mean.count)\n    So accuracy_baseline.total = max(label_mean.total,\n                                    (label_mean.count - label_mean.total))\n    accuracy_baseline.count = label_mean.count\n\n    Args:\n      eval_metrics: A `dict` of metrics to be updated.\n    \"\"\"\n    label_mean_metric = eval_metrics[self._label_mean_key]\n    accuracy_baseline_metric = eval_metrics[self._accuracy_baseline_key]\n    accuracy_baseline_metric.add_update(tf.no_op())\n    accuracy_baseline_metric.total = tf.math.maximum(\n        label_mean_metric.total,\n        label_mean_metric.count - label_mean_metric.total)\n    accuracy_baseline_metric.count = 
label_mean_metric.count\n\n  def _update_auc(self, auc_metric, labels, predictions, weights=None):\n    predictions = tf.cast(predictions, dtype=tf.dtypes.float32)\n    if weights is not None:\n      weights = tf.compat.v2.__internal__.ops.broadcast_weights(weights, predictions)\n    auc_metric.update_state(\n        y_true=labels, y_pred=predictions, sample_weight=weights)\n\n  def update_metrics(self,\n                     eval_metrics,\n                     features,\n                     logits,\n                     labels,\n                     regularization_losses=None):\n    \"\"\"Updates eval metrics. See `base_head.Head` for details.\"\"\"\n    preds = self.predictions(logits)\n    class_ids = preds[prediction_keys.PredictionKeys.CLASS_IDS]\n    logits = base_head.check_logits_final_dim(logits, self.logits_dimension)\n    labels = self._processed_labels(logits, labels)\n    unweighted_loss, weights = self._unweighted_loss_and_weights(\n        logits, labels, features)\n    # Update metrics.\n    eval_metrics[self._loss_mean_key].update_state(\n        values=unweighted_loss, sample_weight=weights)\n    eval_metrics[self._accuracy_key].update_state(\n        y_true=labels, y_pred=class_ids, sample_weight=weights)\n    eval_metrics[self._precision_key].update_state(\n        y_true=labels, y_pred=class_ids, sample_weight=weights)\n    eval_metrics[self._recall_key].update_state(\n        y_true=labels, y_pred=class_ids, sample_weight=weights)\n    logistic_key = prediction_keys.PredictionKeys.LOGISTIC\n    predictions = self.predictions(logits, [logistic_key])\n    logistic = predictions[logistic_key]\n    base_head.update_metric_with_broadcast_weights(\n        eval_metrics[self._prediction_mean_key], logistic, weights)\n    base_head.update_metric_with_broadcast_weights(\n        eval_metrics[self._label_mean_key], labels, weights)\n    self._update_accuracy_baseline(eval_metrics)\n    self._update_auc(\n        auc_metric=eval_metrics[self._auc_key],\n 
       labels=labels,\n        predictions=logistic,\n        weights=weights)\n    self._update_auc(\n        auc_metric=eval_metrics[self._auc_pr_key],\n        labels=labels,\n        predictions=logistic,\n        weights=weights)\n    if regularization_losses is not None:\n      regularization_loss = tf.math.add_n(regularization_losses)\n      eval_metrics[self._loss_regularization_key].update_state(\n          values=regularization_loss)\n    for i in range(len(self._thresholds)):\n      eval_metrics[self._accuracy_keys[i]].update_state(\n          y_true=labels, y_pred=logistic, sample_weight=weights)\n      eval_metrics[self._precision_keys[i]].update_state(\n          y_true=labels, y_pred=logistic, sample_weight=weights)\n      eval_metrics[self._recall_keys[i]].update_state(\n          y_true=labels, y_pred=logistic, sample_weight=weights)\n    return eval_metrics\n\n  def _create_tpu_estimator_spec(self,\n                                 features,\n                                 mode,\n                                 logits,\n                                 labels=None,\n                                 optimizer=None,\n                                 trainable_variables=None,\n                                 train_op_fn=None,\n                                 update_ops=None,\n                                 regularization_losses=None):\n    \"\"\"Returns an `EstimatorSpec`.\n\n    Args:\n      features: Input `dict` mapping string feature names to `Tensor` or\n        `SparseTensor` objects containing the values for that feature in a\n        minibatch. Often to be used to fetch example-weight tensor.\n      mode: Estimator's `ModeKeys`.\n      logits: Logits `Tensor` with shape `[D0, D1, ... DN, 1]`. For many\n        applications, the shape is `[batch_size, 1]`.\n      labels: Labels integer or string `Tensor` with shape matching `logits`,\n        namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. 
`labels` is required\n        argument when `mode` equals `TRAIN` or `EVAL`.\n      optimizer: An `tf_keras.optimizers.Optimizer` instance to optimize the\n        loss in TRAIN mode. Namely, sets `train_op = optimizer.get_updates(loss,\n        trainable_variables)`, which updates variables to minimize `loss`.\n      trainable_variables: A list or tuple of `Variable` objects to update to\n        minimize `loss`. In Tensorflow 1.x, by default these are the list of\n        variables collected in the graph under the key\n        `GraphKeys.TRAINABLE_VARIABLES`. As Tensorflow 2.x doesn't have\n        collections and GraphKeys, trainable_variables need to be passed\n        explicitly here.\n      train_op_fn: Function that takes a scalar loss `Tensor` and returns\n        `train_op`. Used if `optimizer` is `None`.\n      update_ops: A list or tuple of update ops to be run at training time. For\n        example, layers such as BatchNormalization create mean and variance\n        update ops that need to be run at training time. In Tensorflow 1.x,\n        these are thrown into an UPDATE_OPS collection. As Tensorflow 2.x\n        doesn't have collections, update_ops need to be passed explicitly here.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses. 
These losses are\n        usually expressed as a batch average, so for best results users need to\n        set `loss_reduction=SUM_OVER_BATCH_SIZE` when creating the head to avoid\n        scaling errors.\n\n    Returns:\n      `EstimatorSpec`.\n\n    Raises:\n      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN\n        mode, or if both are set.\n    \"\"\"\n    with ops.name_scope(self._name, 'head'):\n      # Predict.\n      pred_keys = prediction_keys.PredictionKeys\n      predictions = self.predictions(logits)\n      if mode == ModeKeys.PREDICT:\n        probabilities = predictions[pred_keys.PROBABILITIES]\n        logistic = predictions[pred_keys.LOGISTIC]\n        classifier_output = base_head.classification_output(\n            scores=probabilities,\n            n_classes=2,\n            label_vocabulary=self._label_vocabulary)\n        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n            mode=ModeKeys.PREDICT,\n            predictions=predictions,\n            export_outputs={\n                base_head.DEFAULT_SERVING_KEY: classifier_output,\n                base_head.CLASSIFY_SERVING_KEY: classifier_output,\n                base_head.REGRESS_SERVING_KEY:\n                    export_output.RegressionOutput(value=logistic),\n                base_head.PREDICT_SERVING_KEY:\n                    export_output.PredictOutput(predictions)\n            })\n      regularized_training_loss = self.loss(\n          logits=logits,\n          labels=labels,\n          features=features,\n          mode=mode,\n          regularization_losses=regularization_losses)\n      # Eval.\n      if mode == ModeKeys.EVAL:\n        eval_metrics = self.metrics(regularization_losses=regularization_losses)\n        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n            mode=ModeKeys.EVAL,\n            predictions=predictions,\n            loss=regularized_training_loss,\n            
eval_metrics=base_head.create_eval_metrics_tuple(\n                self.update_metrics, {\n                    'eval_metrics': eval_metrics,\n                    'features': features,\n                    'logits': logits,\n                    'labels': labels,\n                    'regularization_losses': regularization_losses\n                }))\n      # Train.\n      train_op = base_head.create_estimator_spec_train_op(\n          head_name=self._name,\n          optimizer=optimizer,\n          train_op_fn=train_op_fn,\n          update_ops=update_ops,\n          trainable_variables=trainable_variables,\n          regularized_training_loss=regularized_training_loss,\n          loss_reduction=self._loss_reduction)\n    # Create summary.\n    base_head.create_estimator_spec_summary(\n        regularized_training_loss=regularized_training_loss,\n        regularization_losses=regularization_losses,\n        summary_key_fn=self._summary_key)\n    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n        mode=ModeKeys.TRAIN,\n        predictions=predictions,\n        loss=regularized_training_loss,\n        train_op=train_op)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/binary_class_head_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for binary_class_head.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import dnn\nfrom tensorflow_estimator.python.estimator.canned import dnn_testing_utils\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.head import binary_class_head as head_lib\nfrom tensorflow_estimator.python.estimator.head import head_utils as test_lib\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass BinaryClassHeadTest(tf.test.TestCase):\n\n  def test_threshold_too_small(self):\n    with self.assertRaisesRegexp(ValueError, r'thresholds not in \\(0, 1\\)'):\n      head_lib.BinaryClassHead(thresholds=(0., 0.5))\n\n  def test_threshold_too_large(self):\n    with self.assertRaisesRegexp(ValueError, r'thresholds not in \\(0, 1\\)'):\n      
head_lib.BinaryClassHead(thresholds=(0.5, 1.))\n\n  def test_invalid_loss_reduction(self):\n    with self.assertRaisesRegexp(\n        ValueError, r'Invalid loss_reduction: invalid_loss_reduction'):\n      head_lib.BinaryClassHead(loss_reduction='invalid_loss_reduction')\n    with self.assertRaisesRegexp(ValueError, r'Invalid loss_reduction: none'):\n      head_lib.BinaryClassHead(loss_reduction=tf.losses.Reduction.NONE)\n\n  def test_loss_fn_arg_labels_missing(self):\n\n    def _loss_fn(logits):\n      del logits  # Unused\n\n    with self.assertRaisesRegexp(\n        ValueError, r'loss_fn must contain argument: labels\\. '\n        r'Given arguments: \\(\\'logits\\',\\)'):\n      head_lib.BinaryClassHead(loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_logits_missing(self):\n\n    def _loss_fn(labels):\n      del labels  # unused\n\n    with self.assertRaisesRegexp(\n        ValueError, r'loss_fn must contain argument: logits\\. '\n        r'Given arguments: \\(\\'labels\\',\\)'):\n      head_lib.BinaryClassHead(loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_features_ok(self):\n\n    def _loss_fn(labels, logits, features):\n      del labels, logits, features  # Unused\n      head_lib.BinaryClassHead(loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_invalid(self):\n\n    def _loss_fn(labels, logits, name=None):\n      del labels, logits, name  # Unused\n\n    with self.assertRaisesRegexp(ValueError,\n                                 r'loss_fn has unexpected args: \\[\\'name\\'\\]'):\n      head_lib.BinaryClassHead(loss_fn=_loss_fn)\n\n  def test_invalid_logits_shape(self):\n    head = head_lib.BinaryClassHead()\n    self.assertEqual(1, head.logits_dimension)\n\n    # Logits should be shape (batch_size, 1).\n    logits_2x2 = np.array(((45., 44.), (41., 42.),))\n\n    pred_key = prediction_keys.PredictionKeys.PROBABILITIES\n    # Static shape.\n    with self.assertRaisesRegexp(ValueError, 'logits shape'):\n      preds = head.predictions(logits_2x2, [pred_key])\n      
self.evaluate(preds[pred_key])\n    if tf.executing_eagerly():\n      return\n\n    # Dynamic shape only works in Graph mode.\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42.,),))},\n        mode=ModeKeys.PREDICT,\n        logits=logits_placeholder,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    with self.cached_session():\n      with self.assertRaisesRegexp(tf.errors.OpError, 'logits shape'):\n        spec.predictions[pred_key].eval({logits_placeholder: logits_2x2})\n\n  def test_invalid_labels_shape(self):\n    head = head_lib.BinaryClassHead()\n    self.assertEqual(1, head.logits_dimension)\n\n    # Labels and logits should be shape (batch_size, 1).\n    labels_2x2 = np.array(((45., 44.), (41., 42.),))\n    logits_2x1 = np.array(((45.,), (41.,),))\n    features = {'x': np.array(((42.,),))}\n\n    # Static shape.\n    with self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):\n      training_loss = head.loss(\n          logits=logits_2x1,\n          labels=labels_2x2,\n          features=features,\n          mode=ModeKeys.EVAL)\n      self.evaluate(training_loss)\n    if tf.executing_eagerly():\n      return\n\n    # Dynamic shape only works in Graph mode.\n    labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    training_loss = head.loss(\n        logits=logits_placeholder,\n        labels=labels_placeholder,\n        features=features,\n        mode=ModeKeys.EVAL)\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[expected_labels_shape: \\] \\[2 1\\] \\[labels_shape: \\] \\[2 2\\]'):\n        training_loss.eval({\n            logits_placeholder: logits_2x1,\n            labels_placeholder: labels_2x2\n        })\n\n  def 
test_incompatible_labels_shape(self):\n    head = head_lib.BinaryClassHead()\n    self.assertEqual(1, head.logits_dimension)\n\n    # Both logits and labels should be shape (batch_size, 1).\n    values_2x1 = np.array(((0.,), (1.,),))\n    values_3x1 = np.array(((0.,), (1.,), (0.,),))\n    features = {'x': values_2x1}\n\n    # Static shape for eager mode.\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, 'labels shape'):\n        head.loss(\n            logits=values_2x1,\n            labels=values_3x1,\n            features=features,\n            mode=ModeKeys.EVAL)\n      with self.assertRaisesRegexp(ValueError, 'labels shape'):\n        head.loss(\n            logits=values_3x1,\n            labels=values_2x1,\n            features=features,\n            mode=ModeKeys.EVAL)\n      return\n\n    # Static shape for Graph mode.\n    with self.assertRaisesRegexp(ValueError,\n                                 'logits and labels must have the same shape'):\n      head.loss(\n          logits=values_2x1,\n          labels=values_3x1,\n          features=features,\n          mode=ModeKeys.EVAL)\n    with self.assertRaisesRegexp(ValueError,\n                                 'logits and labels must have the same shape'):\n      head.loss(\n          logits=values_3x1,\n          labels=values_2x1,\n          features=features,\n          mode=ModeKeys.EVAL)\n    # Dynamic shape only works in Graph mode.\n    labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    training_loss = head.loss(\n        logits=logits_placeholder,\n        labels=labels_placeholder,\n        features=features,\n        mode=ModeKeys.EVAL)\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[expected_labels_shape: \\] \\[3 1\\] \\[labels_shape: \\] \\[2 1\\]'):\n        training_loss.eval({\n        
    labels_placeholder: values_2x1,\n            logits_placeholder: values_3x1\n        })\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[expected_labels_shape: \\] \\[2 1\\] \\[labels_shape: \\] \\[3 1\\]'):\n        training_loss.eval({\n            labels_placeholder: values_3x1,\n            logits_placeholder: values_2x1\n        })\n\n  def test_predict(self):\n    head = head_lib.BinaryClassHead()\n    self.assertEqual(1, head.logits_dimension)\n\n    logits = [[0.3], [-0.4]]\n    expected_logistics = [[0.574443], [0.401312]]\n    expected_probabilities = [[0.425557, 0.574443], [0.598688, 0.401312]]\n    expected_class_ids = [[1], [0]]\n    expected_classes = [[b'1'], [b'0']]\n    expected_all_class_ids = [[0, 1]] * 2\n    expected_all_classes = [[b'0', b'1']] * 2\n    expected_export_classes = [[b'0', b'1']] * 2\n\n    keys = prediction_keys.PredictionKeys\n    preds = head.predictions(logits, [\n        keys.LOGITS, keys.LOGISTIC, keys.PROBABILITIES, keys.CLASS_IDS,\n        keys.CLASSES, keys.ALL_CLASS_IDS, keys.ALL_CLASSES\n    ])\n    self.assertAllClose(logits, self.evaluate(preds[keys.LOGITS]))\n    self.assertAllClose(expected_logistics, self.evaluate(preds[keys.LOGISTIC]))\n    self.assertAllClose(expected_probabilities,\n                        self.evaluate(preds[keys.PROBABILITIES]))\n    self.assertAllClose(expected_class_ids,\n                        self.evaluate(preds[keys.CLASS_IDS]))\n    self.assertAllClose(expected_all_class_ids,\n                        self.evaluate(preds[keys.ALL_CLASS_IDS]))\n    self.assertAllEqual(expected_classes, self.evaluate(preds[keys.CLASSES]))\n    self.assertAllEqual(expected_all_classes,\n                        self.evaluate(preds[keys.ALL_CLASSES]))\n    if tf.executing_eagerly():\n      return\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42,),), 
dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    # Assert spec contains expected tensors.\n    self.assertIsNone(spec.loss)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNone(spec.train_op)\n    self.assertItemsEqual(('classification', 'regression', 'predict',\n                           test_lib._DEFAULT_SERVING_KEY),\n                          spec.export_outputs.keys())\n    test_lib._assert_no_hooks(self, spec)\n\n    # Assert predictions.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      predictions = sess.run(spec.predictions)\n      self.assertAllClose(logits,\n                          predictions[prediction_keys.PredictionKeys.LOGITS])\n      self.assertAllClose(expected_logistics,\n                          predictions[prediction_keys.PredictionKeys.LOGISTIC])\n      self.assertAllClose(\n          expected_probabilities,\n          predictions[prediction_keys.PredictionKeys.PROBABILITIES])\n      self.assertAllClose(expected_class_ids,\n                          predictions[prediction_keys.PredictionKeys.CLASS_IDS])\n      self.assertAllEqual(expected_classes,\n                          predictions[prediction_keys.PredictionKeys.CLASSES])\n      self.assertAllClose(\n          expected_all_class_ids,\n          predictions[prediction_keys.PredictionKeys.ALL_CLASS_IDS])\n      self.assertAllEqual(\n          expected_all_classes,\n          predictions[prediction_keys.PredictionKeys.ALL_CLASSES])\n      self.assertAllClose(\n          expected_probabilities,\n          sess.run(spec.export_outputs[test_lib._DEFAULT_SERVING_KEY].scores))\n      self.assertAllEqual(\n          expected_export_classes,\n          sess.run(spec.export_outputs[test_lib._DEFAULT_SERVING_KEY].classes))\n      
self.assertAllClose(expected_logistics,\n                          sess.run(spec.export_outputs['regression'].value))\n\n  def test_predict_with_vocabulary_list(self):\n    head = head_lib.BinaryClassHead(label_vocabulary=['aang', 'iroh'])\n\n    logits = [[1.], [0.]]\n    expected_classes = [[b'iroh'], [b'aang']]\n\n    pred_key = prediction_keys.PredictionKeys.CLASSES\n    if tf.executing_eagerly():\n      preds = head.predictions(logits, [pred_key])\n      self.assertAllEqual(expected_classes, preds[pred_key])\n      return\n    preds = head.predictions(logits, [pred_key])\n    with self.cached_session():\n      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllEqual(expected_classes, preds[pred_key].eval())\n\n  def test_eval_create_loss(self):\n    head = head_lib.BinaryClassHead()\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n\n    # loss = sum(cross_entropy(labels, logits)) / batch_size\n    #      = sum([0, 41]) / 2 = 20.5\n    expected_training_loss = 20.5\n    # Create loss.\n    training_loss = head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)\n    self.assertAllClose(\n        expected_training_loss,\n        self.evaluate(training_loss),\n        rtol=1e-2,\n        atol=1e-2)\n\n  def test_eval_create_loss_loss_fn(self):\n    \"\"\"Tests head.create_loss for eval mode and custom loss_fn.\"\"\"\n    loss = np.array([[1.], [2.]], dtype=np.float32)\n    logits_input = np.array([[-10.], [10.]], dtype=np.float32)\n    labels_input = np.array([[1], [0]], dtype=np.int64)\n\n    def _loss_fn(labels, logits):\n      check_labels = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(labels, labels_input)), data=[labels])\n      check_logits = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(logits, logits_input)), data=[logits])\n      with 
tf.control_dependencies([check_labels, check_logits]):\n        return tf.constant(loss)\n\n    head = head_lib.BinaryClassHead(loss_fn=_loss_fn)\n\n    actual_training_loss = head.loss(\n        logits=logits_input,\n        labels=labels_input,\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.EVAL)\n    self.assertAllClose(np.sum(loss) / 2., self.evaluate(actual_training_loss))\n\n  def test_eval_create_loss_loss_fn_wrong_shape(self):\n    \"\"\"Tests custom loss_fn that returns Tensor of unexpected shape.\"\"\"\n    loss = np.array([1., 2.], dtype=np.float32)\n\n    def _loss_fn(labels, logits):\n      del labels, logits  # Unused\n      return tf.constant(loss)\n\n    head = head_lib.BinaryClassHead(loss_fn=_loss_fn)\n\n    logits = np.array([[-10.], [10.]], dtype=np.float32)\n    labels = np.array([[1], [0]], dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, 'loss_shape'):\n        head.loss(\n            logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)\n    else:\n      actual_training_loss = head.loss(\n          logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[loss_fn must return Tensor of shape \\[D0, D1, ... DN, 1\\]\\. \\] '\n          r'\\[logits_shape: \\] \\[2 1\\] \\[loss_shape: \\] \\[2\\]'):\n        with self.cached_session():\n          test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())\n          actual_training_loss.eval()\n\n  def test_eval_labels_none(self):\n    \"\"\"Tests that error is raised when labels is None.\"\"\"\n    head = head_lib.BinaryClassHead()\n\n    with self.assertRaisesRegexp(\n        ValueError, r'You must provide a labels Tensor\\. 
Given: None\\.'):\n      head.loss(\n          logits=np.array(((45,), (-41,),), dtype=np.float32),\n          labels=None,\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.EVAL)\n\n  def test_eval(self):\n    head = head_lib.BinaryClassHead()\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n\n    # loss = sum(cross_entropy(labels, logits)) / batch_size\n    #      = sum(0, 41) / 2 = 41 / 2 = 20.5\n    expected_loss = 20.5\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: expected_loss,\n        keys.ACCURACY: 1. / 2,\n        keys.PRECISION: 1.,\n        keys.RECALL: 1. / 2,\n        keys.PREDICTION_MEAN: 1. / 2,\n        keys.LABEL_MEAN: 2. / 2,\n        keys.ACCURACY_BASELINE: 2. / 2,\n        keys.AUC: 0.,\n        keys.AUC_PR: 1.,\n    }\n    if tf.executing_eagerly():\n      eval_metrics = head.metrics()\n      updated_metrics = head.update_metrics(eval_metrics, features, logits,\n                                            labels)\n      self.assertItemsEqual(expected_metrics.keys(), updated_metrics.keys())\n      self.assertAllClose(\n          expected_metrics,\n          {k: updated_metrics[k].result() for k in updated_metrics})\n      loss = head.loss(labels, logits, features=features, mode=ModeKeys.EVAL)\n      self.assertIsNotNone(loss)\n      self.assertAllClose(expected_loss, loss)\n      return\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    # Assert spec contains expected tensors.\n    self.assertIsNotNone(spec.loss)\n    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())\n    self.assertIsNone(spec.train_op)\n    
self.assertIsNone(spec.export_outputs)\n    test_lib._assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, _ = sess.run((spec.loss, update_ops))\n      self.assertAllClose(expected_loss, loss)\n      # Check results of value ops (in `metrics`).\n      self.assertAllClose(expected_metrics,\n                          {k: value_ops[k].eval() for k in value_ops})\n\n  def test_eval_metric_ops_with_head_name(self):\n    head = head_lib.BinaryClassHead(name='some_binary_head')\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    keys = metric_keys.MetricKeys\n    expected_metric_keys = [\n        '{}/some_binary_head'.format(keys.LOSS_MEAN),\n        '{}/some_binary_head'.format(keys.ACCURACY),\n        '{}/some_binary_head'.format(keys.PRECISION),\n        '{}/some_binary_head'.format(keys.RECALL),\n        '{}/some_binary_head'.format(keys.PREDICTION_MEAN),\n        '{}/some_binary_head'.format(keys.LABEL_MEAN),\n        '{}/some_binary_head'.format(keys.ACCURACY_BASELINE),\n        '{}/some_binary_head'.format(keys.AUC),\n        '{}/some_binary_head'.format(keys.AUC_PR),\n    ]\n    eval_metrics = head.metrics()\n    updated_metrics = head.update_metrics(eval_metrics, features, logits,\n                                          labels)\n    self.assertItemsEqual(expected_metric_keys, updated_metrics.keys())\n\n  def test_eval_with_regularization_losses(self):\n    head = head_lib.BinaryClassHead()\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), 
dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    regularization_losses = [1.5, 0.5]\n    expected_regularization_loss = 2.\n    # unregularized_loss = sum(cross_entropy(labels, logits)) / batch_size\n    #                    = sum(0, 41) / 2 = 20.5\n    expected_unregularized_loss = 20.5\n    expected_regularized_loss = (\n        expected_unregularized_loss + expected_regularization_loss)\n\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: expected_unregularized_loss,\n        keys.LOSS_REGULARIZATION: expected_regularization_loss,\n        keys.ACCURACY: 1. / 2,\n        keys.PRECISION: 1.,\n        keys.RECALL: 1. / 2,\n        keys.PREDICTION_MEAN: 1. / 2,\n        keys.LABEL_MEAN: 2. / 2,\n        keys.ACCURACY_BASELINE: 2. / 2,\n        keys.AUC: 0.,\n        keys.AUC_PR: 1.,\n    }\n    if tf.executing_eagerly():\n      eval_metrics = head.metrics(regularization_losses=regularization_losses)\n      updated_metrics = head.update_metrics(\n          eval_metrics,\n          features,\n          logits,\n          labels,\n          regularization_losses=regularization_losses)\n      # Assert metrics.\n      self.assertAllClose(\n          expected_metrics,\n          {k: updated_metrics[k].result() for k in updated_metrics})\n      return\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels,\n        regularization_losses=regularization_losses,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in 
spec.eval_metric_ops}\n      loss, _ = sess.run((spec.loss, update_ops))\n      self.assertAllClose(expected_regularized_loss, loss)\n      # Check results of value ops (in `metrics`).\n      self.assertAllClose(expected_metrics,\n                          {k: value_ops[k].eval() for k in value_ops})\n\n  def test_eval_with_vocabulary_list_create_loss(self):\n    head = head_lib.BinaryClassHead(label_vocabulary=['aang', 'iroh'])\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = [[b'iroh'], [b'iroh']]\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # loss = sum(cross_entropy(labels, logits)) / batch_size\n    #      = sum([0, 41]) / 2 = 20.5\n    expected_training_loss = 20.5\n    # Create loss.\n    if tf.executing_eagerly():\n      training_loss = head.loss(\n          logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)\n      self.assertAllClose(\n          expected_training_loss, training_loss, rtol=1e-2, atol=1e-2)\n      return\n    training_loss = head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)\n    with self.cached_session():\n      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(expected_training_loss, training_loss.eval())\n\n  def test_eval_with_vocabulary_list(self):\n    head = head_lib.BinaryClassHead(label_vocabulary=['aang', 'iroh'])\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = [[b'iroh'], [b'iroh']]\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n\n    accuracy_key = metric_keys.MetricKeys.ACCURACY\n    if tf.executing_eagerly():\n      eval_metrics = head.metrics()\n      updated_metrics = head.update_metrics(eval_metrics, features, logits,\n                                            labels)\n      self.assertAllClose(1. 
/ 2, updated_metrics[accuracy_key].result())\n      return\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      sess.run(update_ops)\n      self.assertAllClose(1. / 2, value_ops[accuracy_key].eval())\n\n  def test_eval_with_thresholds_create_loss(self):\n    thresholds = [0.25, 0.5, 0.75]\n    head = head_lib.BinaryClassHead(thresholds=thresholds)\n    logits = np.array(((-1,), (1,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # probabilities[i] = 1/(1 + exp(-logits[i])) =>\n    # probabilities = [1/(1 + exp(1)), 1/(1 + exp(-1))] = [0.269, 0.731]\n    # unreduced_loss = -ln(probabilities[label[i]])) = [-ln(0.269), -ln(0.731)]\n    #      = [1.31304389, 0.31334182]\n    # weighted sum loss = 1.62638571\n    # loss = 0.813192855\n    expected_training_loss = 0.813192855\n    # Create loss.\n    training_loss = head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)\n    self.assertAllClose(\n        expected_training_loss,\n        self.evaluate(training_loss),\n        rtol=1e-2,\n        atol=1e-2)\n\n  def test_eval_with_thresholds(self):\n    thresholds = [0.25, 0.5, 0.75]\n    head = head_lib.BinaryClassHead(thresholds=thresholds)\n    logits = np.array(((-1,), (1,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n\n    # 
probabilities[i] = 1/(1 + exp(-logits[i])) =>\n    # probabilities = [1/(1 + exp(1)), 1/(1 + exp(-1))] = [0.269, 0.731]\n    # loss = -sum(ln(probabilities[label[i]])) / batch_size\n    #      = (-ln(0.269) -ln(0.731)) / 2\n    #      = 1.62652338 / 2\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: 1.62652338 / 2.,\n        keys.ACCURACY: 1. / 2,\n        keys.PRECISION: 1.,\n        keys.RECALL: .5,\n        keys.PREDICTION_MEAN: 1. / 2,\n        keys.LABEL_MEAN: 2. / 2,\n        keys.ACCURACY_BASELINE: 2. / 2,\n        keys.AUC: 0.,\n        keys.AUC_PR: 1.,\n        keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 1.,\n        keys.PRECISION_AT_THRESHOLD % thresholds[0]: 1.,\n        keys.RECALL_AT_THRESHOLD % thresholds[0]: 1.,\n        keys.ACCURACY_AT_THRESHOLD % thresholds[1]: .5,\n        keys.PRECISION_AT_THRESHOLD % thresholds[1]: 1.,\n        keys.RECALL_AT_THRESHOLD % thresholds[1]: .5,\n        keys.ACCURACY_AT_THRESHOLD % thresholds[2]: 0.,\n        keys.PRECISION_AT_THRESHOLD % thresholds[2]: 0.,\n        keys.RECALL_AT_THRESHOLD % thresholds[2]: 0.,\n    }\n    tol = 1e-2\n    if tf.executing_eagerly():\n      # Create loss.\n      training_loss = head.loss(\n          logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)\n      self.assertAllClose(1.62652338 / 2., self.evaluate(training_loss))\n      # Eval metrics.\n      eval_metrics = head.metrics()\n      updated_metrics = head.update_metrics(eval_metrics, features, logits,\n                                            labels)\n      # Assert metrics.\n      self.assertItemsEqual(expected_metrics.keys(), updated_metrics.keys())\n      self.assertAllClose(\n          expected_metrics, {\n              k: self.evaluate(updated_metrics[k].result())\n              for k in updated_metrics\n          },\n          atol=tol,\n          rtol=tol)\n      return\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        
features=features,\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, _ = sess.run((spec.loss, update_ops))\n      self.assertAllClose(1.62652338 / 2., loss)\n      # Check results of value ops (in `metrics`).\n      self.assertAllClose(\n          expected_metrics, {k: value_ops[k].eval() for k in value_ops},\n          atol=tol,\n          rtol=tol)\n\n  def test_train_create_loss(self):\n    head = head_lib.BinaryClassHead()\n\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.float64)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # unreduced_loss = cross_entropy(labels, logits) = [0, 41]\n    # weights default to 1.\n    # training loss = (1 * 0 + 1 * 41) / 2 = 20.5\n    expected_training_loss = 20.5\n    # Create loss.\n    if tf.executing_eagerly():\n      training_loss = head.loss(labels, logits, features)\n      self.assertAllClose(expected_training_loss, training_loss)\n      return\n\n    training_loss = head.loss(labels, logits, features)\n    with self.cached_session():\n      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(expected_training_loss, training_loss.eval())\n\n  def test_train_create_loss_loss_reduction(self):\n    \"\"\"Tests create_loss with loss_reduction.\"\"\"\n    head = head_lib.BinaryClassHead(loss_reduction=tf.losses.Reduction.SUM)\n\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = 
np.array(((1,), (1,),), dtype=np.float64)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # unreduced_loss = cross_entropy(labels, logits) = [0, 41]\n    # weights default to 1.\n    # training loss = (1 * 0 + 1 * 41)\n    expected_training_loss = 41.\n    # Create loss.\n    if tf.executing_eagerly():\n      training_loss = head.loss(labels, logits, features)\n      self.assertAllClose(expected_training_loss, training_loss)\n      return\n\n    training_loss = head.loss(labels, logits, features)\n    with self.cached_session():\n      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(expected_training_loss, training_loss.eval())\n\n  def test_train_labels_none(self):\n    \"\"\"Tests that error is raised when labels is None.\"\"\"\n    head = head_lib.BinaryClassHead()\n    with self.assertRaisesRegexp(\n        ValueError, r'You must provide a labels Tensor\\. Given: None\\.'):\n      head.loss(\n          logits=np.array(((45,), (-41,),), dtype=np.float32),\n          labels=None,\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN)\n\n  def test_train(self):\n    head = head_lib.BinaryClassHead()\n\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.float64)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # loss = sum(cross_entropy(labels, logits)) / batch_size\n    #      = sum(0, 41) / 2 = 41 / 2 = 20.5\n    expected_loss = 20.5\n    if tf.executing_eagerly():\n      loss = head.loss(\n          logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n      self.assertIsNotNone(loss)\n      self.assertAllClose(expected_loss, loss)\n      return\n\n    expected_train_result = b'my_train_op'\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          
tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    # Assert spec contains expected tensors.\n    self.assertIsNotNone(spec.loss)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    test_lib._assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op))\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      test_lib._assert_simple_summaries(self, {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n      }, summary_str)\n\n  def test_train_one_dim_create_loss(self):\n    \"\"\"Tests create_loss with 1D labels and weights (shape [batch_size]).\"\"\"\n    head = head_lib.BinaryClassHead(weight_column='label_weights')\n\n    # Create estimator spec.\n    logits = np.array(((45,), (-41,), (44,)), dtype=np.float32)\n    labels_rank_1 = np.array((1., 1., 0.,))\n    weights_rank_1 = np.array(((1., .1, 1.5,)), dtype=np.float64)\n    features = {\n        'x': np.array(((42.,), (43.,), (44.,)), dtype=np.float32),\n        'label_weights': weights_rank_1,\n    }\n    # unreduced_loss = cross_entropy(labels, logits) = [0, 41, 44]\n    # weights are reshaped to [3, 1] to match logits.\n    # training loss = (1 * 0 + .1 * 41 + 1.5 * 44) / 3 = 23.366666667\n    
expected_training_loss = 23.366666667\n    # Create loss.\n    if tf.executing_eagerly():\n      training_loss = head.loss(labels_rank_1, logits, features)\n      self.assertAllClose(expected_training_loss, training_loss)\n      return\n\n    training_loss = head.loss(labels_rank_1, logits, features)\n    self.assertAllClose(expected_training_loss, self.evaluate(training_loss))\n\n  def test_train_one_dim(self):\n    \"\"\"Tests train with 1D labels and weights (shape [batch_size]).\"\"\"\n    head = head_lib.BinaryClassHead(weight_column='label_weights')\n\n    # Create estimator spec.\n    logits = np.array(((45,), (-41,), (44,)), dtype=np.float32)\n    labels_rank_1 = np.array((1., 1., 0.,))\n    weights_rank_1 = np.array(((1., .1, 1.5,)), dtype=np.float64)\n    self.assertEqual((3,), labels_rank_1.shape)\n    self.assertEqual((3,), weights_rank_1.shape)\n    features = {\n        'x': np.array(((42.,), (43.,), (44.,)), dtype=np.float32),\n        'label_weights': weights_rank_1,\n    }\n    # losses = label_weights*cross_entropy(labels, logits)\n    #        = (1*0 + .1*41 + 1.5*44) = (0, 4.1, 66)\n    # loss = sum(losses) / batch_size = (0 + 4.1 + 66) / 3 = 23.366666667\n    expected_loss = 23.366666667\n    if tf.executing_eagerly():\n      loss = head.loss(\n          logits=logits,\n          labels=labels_rank_1,\n          features=features,\n          mode=ModeKeys.TRAIN)\n      self.assertIsNotNone(loss)\n      self.assertAllClose(expected_loss, loss)\n      return\n\n    expected_train_result = b'my_train_op'\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        
labels=labels_rank_1,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    # Assert spec contains expected tensors.\n    self.assertIsNotNone(spec.loss)\n    self.assertIsNotNone(spec.train_op)\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op))\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      test_lib._assert_simple_summaries(self, {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n      }, summary_str)\n\n  def test_train_with_regularization_losses(self):\n    head = head_lib.BinaryClassHead()\n\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.float64)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    regularization_losses = [1.5, 0.5]\n    expected_regularization_loss = 2.\n    # unregularized_loss = sum(cross_entropy(labels, logits)) / batch_size\n    #                    = sum(0, 41) / 2 = 20.5\n    # loss = unregularized_loss + regularization_loss = 22.5\n    expected_loss = 22.5\n    if tf.executing_eagerly():\n      loss = head.loss(\n          logits=logits,\n          labels=labels,\n          features=features,\n          mode=ModeKeys.TRAIN,\n          regularization_losses=regularization_losses)\n      self.assertAllClose(expected_loss, loss)\n      return\n\n    expected_train_result = b'my_train_op'\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return 
tf.constant(expected_train_result)\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        regularization_losses=regularization_losses,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op))\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      test_lib._assert_simple_summaries(\n          self, {\n              metric_keys.MetricKeys.LOSS:\n                  expected_loss,\n              metric_keys.MetricKeys.LOSS_REGULARIZATION:\n                  (expected_regularization_loss),\n          }, summary_str)\n\n  def test_float_labels_invalid_values(self):\n    head = head_lib.BinaryClassHead()\n\n    logits = np.array([[0.5], [-0.3]], dtype=np.float32)\n    labels = np.array([[1.2], [0.4]], dtype=np.float32)\n    features = {'x': np.array([[42]], dtype=np.float32)}\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, r'Labels must be <= 2 - 1'):\n        head.loss(\n            logits=logits,\n            labels=labels,\n            features=features,\n            mode=ModeKeys.TRAIN)\n      return\n\n    with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,\n                                 r'Labels must be <= n_classes - 1'):\n      training_loss = head.loss(\n          logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n\n  def test_float_labels_train_create_loss(self):\n    head = head_lib.BinaryClassHead()\n\n    logits = 
np.array([[0.5], [-0.3]], dtype=np.float32)\n    labels = np.array([[0.8], [0.4]], dtype=np.float32)\n    features = {'x': np.array([[42]], dtype=np.float32)}\n    # loss = cross_entropy(labels, logits)\n    #      = -label[i]*sigmoid(logit[i]) -(1-label[i])*sigmoid(-logit[i])\n    #      = [-0.8 * log(sigmoid(0.5)) -0.2 * log(sigmoid(-0.5)),\n    #         -0.4 * log(sigmoid(-0.3)) -0.6 * log(sigmoid(0.3))]\n    #      = [0.57407698418, 0.67435524446]\n    # weighted_sum_loss = 0.57407698418 + 0.67435524446\n    # training_loss = weighted_sum_loss / 2 = 0.62421611432\n    expected_training_loss = 0.62421611432\n    # Create loss.\n    training_loss = head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n    self.assertAllClose(expected_training_loss, self.evaluate(training_loss))\n\n  def test_float_labels_train(self):\n    head = head_lib.BinaryClassHead()\n\n    logits = np.array([[0.5], [-0.3]], dtype=np.float32)\n    labels = np.array([[0.8], [0.4]], dtype=np.float32)\n    expected_train_result = b'my_train_op'\n    features = {'x': np.array([[42]], dtype=np.float32)}\n    # loss = sum(cross_entropy(labels, logits)) / batch_size\n    #      = sum(-label[i]*sigmoid(logit[i]) -(1-label[i])*sigmoid(-logit[i])\n    #           ) / batch_size\n    #      = -0.8 * log(sigmoid(0.5)) -0.2 * log(sigmoid(-0.5)) / 2\n    #        -0.4 * log(sigmoid(-0.3)) -0.6 * log(sigmoid(0.3)) / 2\n    #      = 1.2484322 / 2 = 0.6242161\n    expected_loss = 0.6242161\n    # Create loss.\n    training_loss = head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n    self.assertAlmostEqual(\n        expected_loss, self.evaluate(training_loss), delta=1.e-5)\n    if tf.executing_eagerly():\n      return\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((dnn_testing_utils.assert_close(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32)),)):\n  
      return tf.constant(expected_train_result)\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      loss, train_result = sess.run((spec.loss, spec.train_op))\n      self.assertAlmostEqual(expected_loss, loss, delta=1.e-5)\n      self.assertEqual(expected_train_result, train_result)\n\n  def test_float_labels_eval_create_loss(self):\n    head = head_lib.BinaryClassHead()\n\n    logits = np.array([[0.5], [-0.3]], dtype=np.float32)\n    labels = np.array([[0.8], [0.4]], dtype=np.float32)\n    features = {'x': np.array([[42]], dtype=np.float32)}\n    # unreduced_loss = cross_entropy(labels, logits)\n    #      = -label[i]*sigmoid(logit[i]) -(1-label[i])*sigmoid(-logit[i])\n    #      = [-0.8 * log(sigmoid(0.5)) -0.2 * log(sigmoid(-0.5)),\n    #         -0.4 * log(sigmoid(-0.3)) -0.6 * log(sigmoid(0.3))]\n    #      = [0.57407698418, 0.67435524446]\n    # weighted_sum_loss = 0.57407698418 + 0.67435524446\n    # loss = weighted_sum_loss / batch_size = 1.24843222864 / 2 = 0.62421611432\n    expected_training_loss = 0.62421611432\n    # Create loss.\n    training_loss = head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)\n    self.assertAllClose(\n        expected_training_loss,\n        self.evaluate(training_loss),\n        rtol=1e-2,\n        atol=1e-2)\n\n  def test_float_labels_eval(self):\n    head = head_lib.BinaryClassHead()\n\n    logits = np.array([[0.5], [-0.3]], dtype=np.float32)\n    labels = np.array([[0.8], [0.4]], dtype=np.float32)\n    features = {'x': np.array([[42]], dtype=np.float32)}\n\n    # loss_sum = 
sum(cross_entropy(labels, logits))\n    #      = sum(-label[i]*sigmoid(logit[i]) -(1-label[i])*sigmoid(-logit[i]))\n    #      = -0.8 * log(sigmoid(0.5)) -0.2 * log(sigmoid(-0.5))\n    #        -0.4 * log(sigmoid(-0.3)) -0.6 * log(sigmoid(0.3))\n    #      = 1.2484322\n    # loss = loss_sum / batch_size = 1.2484322 / 2 = 0.6242161\n    expected_loss = 0.6242161\n    # Create loss.\n    training_loss = head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)\n    self.assertAlmostEqual(\n        expected_loss, self.evaluate(training_loss), delta=1.e-5)\n    # Eval metrics.\n    loss_mean_key = metric_keys.MetricKeys.LOSS_MEAN\n    if tf.executing_eagerly():\n      eval_metrics = head.metrics()\n      updated_metrics = head.update_metrics(eval_metrics, features, logits,\n                                            labels)\n      self.assertAlmostEqual(expected_loss,\n                             updated_metrics[loss_mean_key].result().numpy())\n      return\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, _ = sess.run((spec.loss, update_ops))\n      self.assertAlmostEqual(expected_loss, loss, delta=1.e-5)\n      self.assertAlmostEqual(expected_loss, value_ops[loss_mean_key].eval())\n\n  def test_weighted_multi_example_predict(self):\n    \"\"\"3 examples, 1 batch.\"\"\"\n    head = 
head_lib.BinaryClassHead(weight_column='label_weights')\n    # Create estimator spec.\n    logits = np.array(((45,), (-41,), (44,)), dtype=np.int32)\n    pred_keys = prediction_keys.PredictionKeys\n    keys = [\n        pred_keys.LOGITS, pred_keys.LOGISTIC, pred_keys.PROBABILITIES,\n        pred_keys.CLASS_IDS, pred_keys.CLASSES\n    ]\n    predictions = head.predictions(logits, keys)\n    self.assertAllClose(\n        logits.astype(np.float32), self.evaluate(predictions[pred_keys.LOGITS]))\n    self.assertAllClose(\n        tf.math.sigmoid(logits.astype(np.float32)),\n        self.evaluate(predictions[pred_keys.LOGISTIC]))\n    self.assertAllClose([[0., 1.], [1., 0.], [0., 1.]],\n                        self.evaluate(predictions[pred_keys.PROBABILITIES]))\n    self.assertAllClose([[1], [0], [1]],\n                        self.evaluate(predictions[pred_keys.CLASS_IDS]))\n    self.assertAllEqual([[b'1'], [b'0'], [b'1']],\n                        self.evaluate(predictions[pred_keys.CLASSES]))\n\n  def test_weighted_multi_example_eval(self):\n    \"\"\"3 examples, 1 batch.\"\"\"\n    head = head_lib.BinaryClassHead(weight_column='label_weights')\n\n    logits = np.array(((45,), (-41,), (44,)), dtype=np.int32)\n    labels = np.array(((1,), (1,), (0,)), dtype=np.int32)\n    features = {\n        'x': np.array(((42,), (43,), (44,)), dtype=np.int32),\n        'label_weights': np.array(((1.,), (.1,), (1.5,)), dtype=np.float32)\n    }\n    # label_mean = (1*1 + .1*1 + 1.5*0)/(1 + .1 + 1.5) = 1.1/2.6\n    #            = .42307692307\n    expected_label_mean = .42307692307\n    # losses = label_weights*cross_entropy(labels, logits)\n    #        = (1*0 + .1*41 + 1.5*44) = (0, 4.1, 66)\n    # loss = sum(losses) / batch_size = (0 + 4.1 + 66) / 3 = 70.1 / 3 = 23.36667\n    expected_loss = 23.366666667\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        # loss_mean = loss/sum(label_weights) = 70.1/(1 + .1 + 1.5)\n        #           = 70.1/2.6 = 26.9615384615\n  
      keys.LOSS_MEAN: 26.9615384615,\n        # accuracy = (1*1 + .1*0 + 1.5*0)/(1 + .1 + 1.5) = 1/2.6 = .38461538461\n        keys.ACCURACY: .38461538461,\n        keys.PRECISION: 1. / 2.5,\n        keys.RECALL: 1. / 1.1,\n        # prediction_mean = (1*1 + .1*0 + 1.5*1)/(1 + .1 + 1.5) = 2.5/2.6\n        #                 = .96153846153\n        keys.PREDICTION_MEAN: .96153846153,\n        keys.LABEL_MEAN: expected_label_mean,\n        keys.ACCURACY_BASELINE: 1 - expected_label_mean,\n        keys.AUC: .45454565,\n        keys.AUC_PR: .4010627,\n    }\n    if tf.executing_eagerly():\n      eval_metrics = head.metrics()\n      updated_metrics = head.update_metrics(eval_metrics, features, logits,\n                                            labels)\n      self.assertItemsEqual(expected_metrics.keys(), updated_metrics.keys())\n      self.assertAllClose(\n          expected_metrics,\n          {k: updated_metrics[k].result() for k in updated_metrics})\n      loss = head.loss(labels, logits, features=features, mode=ModeKeys.EVAL)\n      self.assertIsNotNone(loss)\n      self.assertAllClose(expected_loss, loss)\n      return\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    # Assert spec contains expected tensors.\n    self.assertIsNotNone(spec.loss)\n    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, _ = sess.run((spec.loss, update_ops))\n      self.assertAllClose(expected_loss, loss)\n      # Check results of value 
ops (in `metrics`).\n      self.assertAllClose(expected_metrics,\n                          {k: value_ops[k].eval() for k in value_ops})\n\n  def test_weighted_multi_example_train(self):\n    \"\"\"3 examples, 1 batch.\"\"\"\n    head = head_lib.BinaryClassHead(weight_column='label_weights')\n\n    # Create estimator spec.\n    logits = np.array(((45,), (-41,), (44,)), dtype=np.float32)\n    features = {\n        'x': np.array(((42.,), (43.,), (44.,)), dtype=np.float32),\n        'label_weights': np.array(((1.,), (.1,), (1.5,)), dtype=np.float64),\n    }\n    labels = np.array(((1.,), (1.,), (0.,)))\n    expected_train_result = b'my_train_op'\n    # losses = label_weights*cross_entropy(labels, logits)\n    #        = (1*0 + .1*41 + 1.5*44) = (0, 4.1, 66)\n    # loss = sum(losses) / batch_size = (0 + 4.1 + 66) / 3 = 23.366666667\n    expected_loss = 23.366666667\n    if tf.executing_eagerly():\n      loss = head.loss(\n          logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n      self.assertIsNotNone(loss)\n      self.assertAllClose(expected_loss, loss)\n      return\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    # Assert spec contains expected tensors.\n    self.assertIsNotNone(spec.loss)\n    self.assertIsNotNone(spec.train_op)\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n 
     loss, train_result, summary_str = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op))\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      test_lib._assert_simple_summaries(\n          self, {metric_keys.MetricKeys.LOSS: expected_loss}, summary_str)\n\n  def test_multi_dim_weighted_train_create_loss(self):\n    \"\"\"Logits and labels of shape [2, 2, 1], weights [2, 2].\"\"\"\n    head = head_lib.BinaryClassHead(weight_column='weights')\n\n    logits = np.array([[[10], [-10]], [[12], [-12]]], dtype=np.float32)\n    labels = np.array([[[0], [0]], [[1], [1]]], dtype=np.float64)\n    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)\n    features = {'weights': weights}\n    # unreduced_loss = cross_entropy(labels, logits) = [[10, 0], [0, 12]].\n    # Weights are reshaped to [2, 2, 1] to match logits.\n    # training_loss = (1*10 + 1.5*0 + 2*0 + 2.5*12) / 2*2 = 40 / 4 = 10\n    expected_training_loss = 10.\n    tol = 1e-2\n    # Create loss.\n    if tf.executing_eagerly():\n      training_loss = head.loss(labels, logits, features, mode=ModeKeys.TRAIN)\n      self.assertAllClose(\n          expected_training_loss, training_loss, rtol=tol, atol=tol)\n      return\n\n    training_loss = head.loss(labels, logits, features)\n    with self.cached_session():\n      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(\n          expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)\n\n  def test_multi_dim_weighted_train(self):\n    \"\"\"Logits and labels of shape [2, 2, 1], weights [2, 2].\"\"\"\n    head = head_lib.BinaryClassHead(weight_column='weights')\n\n    logits = np.array([[[10], [-10]], [[12], [-12]]], dtype=np.float32)\n    labels = np.array([[[0], [0]], [[1], [1]]], dtype=np.float64)\n    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)\n    features = {'weights': weights}\n    # losses = 
cross_entropy(labels, logits) = [[10, 0], [0, 12]].\n    # weighted_sum_loss = 1*10 + 1.5*0 + 2*0 + 2.5*12 = 40\n    # loss = weighted_sum_loss / batch_size = 40 / (2*2) = 10\n    expected_loss = 10.\n    tol = 1e-2\n    # Create loss.\n    if tf.executing_eagerly():\n      training_loss = head.loss(labels, logits, features, mode=ModeKeys.TRAIN)\n      self.assertAllClose(expected_loss, training_loss, rtol=tol, atol=tol)\n      return\n\n    expected_train_result = 'my_train_op'\n\n    def _train_op_fn(loss):\n      return tf.strings.join([\n          tf.constant(expected_train_result),\n          tf.strings.as_string(loss, precision=2)\n      ])\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      loss, train_result = sess.run((spec.loss, spec.train_op))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),\n          train_result)\n\n  def test_multi_dim_train_weights_wrong_inner_dim(self):\n    \"\"\"Logits and labels of shape [2, 2, 1], weights [2, 1].\"\"\"\n    head = head_lib.BinaryClassHead(weight_column='weights')\n\n    logits = np.array([[[10], [-10]], [[12], [-12]]], dtype=np.float32)\n    labels = np.array([[[0], [0]], [[1], [1]]], dtype=np.float64)\n    weights = np.array([[1.], [2.]], dtype=np.float32)\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, 'weights shape'):\n        head.loss(\n            logits=logits,\n            labels=labels,\n            features={'weights': weights},\n            
mode=ModeKeys.TRAIN)\n      return\n\n    def _no_op_train_fn(loss):\n      del loss\n      return tf.no_op()\n\n    spec = head.create_estimator_spec(\n        features={'weights': weights},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_no_op_train_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    with self.cached_session():\n      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[logits_shape: \\] \\[2 2 1\\] \\[weights_shape: \\] \\[2 1\\]'):\n        spec.loss.eval()\n\n  def test_multi_dim_train_weights_wrong_outer_dim(self):\n    \"\"\"Logits and labels of shape [2, 2, 1], weights [2, 2, 2].\"\"\"\n    head = head_lib.BinaryClassHead(weight_column='weights')\n    logits = np.array([[[10], [-10]], [[12], [-12]]], dtype=np.float32)\n    labels = np.array([[[0], [0]], [[1], [1]]], dtype=np.float64)\n    weights = np.array([[[1., 1.1], [1.5, 1.6]], [[2., 2.1], [2.5, 2.6]]])\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, 'weights shape'):\n        head.loss(\n            logits=logits,\n            labels=labels,\n            features={'weights': weights},\n            mode=ModeKeys.TRAIN)\n      return\n    weights_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n\n    def _no_op_train_fn(loss):\n      del loss\n      return tf.no_op()\n\n    spec = head.create_estimator_spec(\n        features={'weights': weights_placeholder},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_no_op_train_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    with self.cached_session():\n      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n     
     r'\\[logits_shape: \\]\\s\\[2 2 1\\]\\s\\[weights_shape: \\]\\s\\[2 2 2\\]'):\n        spec.loss.eval({weights_placeholder: weights})\n\n  def test_multi_dim_weighted_eval(self):\n    \"\"\"Logits and labels of shape [2, 2, 1], weights [2, 2].\"\"\"\n    head = head_lib.BinaryClassHead(weight_column='weights')\n\n    logits = np.array([[[10], [-10]], [[12], [-12]]], dtype=np.float32)\n    labels = np.array([[[0], [0]], [[1], [1]]], dtype=np.float64)\n    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)\n    # losses = cross_entropy(labels, logits) = [[10, 0], [0, 12]].\n    # weighted_sum_loss = 1*10 + 1.5*0 + 2*0 + 2.5*12 = 40\n    # loss = weighted_sum_loss / batch_size = 40 / (2*2) = 10.\n    weighted_sum_loss = 40.\n    expected_loss = 10.\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: weighted_sum_loss / np.sum(weights),\n        keys.ACCURACY: (1. * 0. + 1.5 * 1. + 2. * 1. + 2.5 * 0.) / np.sum(weights),\n        keys.PRECISION: 2.0 / 3.0,\n        keys.RECALL: 2.0 / 4.5,\n        keys.PREDICTION_MEAN:\n            (1. * 1 + 1.5 * 0 + 2. * 1 + 2.5 * 0) / np.sum(weights),\n        keys.LABEL_MEAN:\n            (1. * 0 + 1.5 * 0 + 2. * 1 + 2.5 * 1) / np.sum(weights),\n        keys.ACCURACY_BASELINE:\n            (1. * 0 + 1.5 * 0 + 2. 
* 1 + 2.5 * 1) / np.sum(weights),\n        keys.AUC: 0.5222,\n        keys.AUC_PR: 0.6582,\n    }\n    tol = 1e-2\n    if tf.executing_eagerly():\n      loss = head.loss(\n          logits=logits,\n          labels=labels,\n          features={'weights': weights},\n          mode=ModeKeys.TRAIN)\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      eval_metrics = head.metrics()\n      updated_metrics = head.update_metrics(\n          eval_metrics,\n          features={'weights': weights},\n          logits=logits,\n          labels=labels)\n      # Assert metrics.\n      self.assertAllClose(\n          expected_metrics,\n          {k: updated_metrics[k].result() for k in updated_metrics},\n          rtol=tol,\n          atol=tol)\n      return\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features={'weights': weights},\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, _ = sess.run((spec.loss, update_ops))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      # Check results of value ops (in `metrics`).\n      self.assertAllClose(\n          expected_metrics, {k: value_ops[k].eval() for k in value_ops},\n          rtol=tol,\n          atol=tol)\n\n\n@test_util.deprecated_graph_mode_only\nclass BinaryClassHeadForEstimator(tf.test.TestCase):\n  \"\"\"Tests for create_estimator_spec running in Graph mode only.\"\"\"\n\n  def test_invalid_trainable_variables(self):\n    head = head_lib.BinaryClassHead()\n\n    class _Optimizer(tf_keras.optimizers.Optimizer):\n\n      def get_updates(self, loss, params):\n        del 
params\n        return [\n            tf.strings.join([\n                tf.constant('my_train_op'),\n                tf.strings.as_string(loss, precision=2)\n            ])\n        ]\n\n      def get_config(self):\n        config = super(_Optimizer, self).get_config()\n        return config\n\n    with self.assertRaisesRegexp(ValueError,\n                                 r'trainable_variables cannot be None'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array(((45,), (-41,),), dtype=np.float32),\n          labels=np.array(((1,), (1,),), dtype=np.float64),\n          optimizer=_Optimizer('my_optimizer'),\n          trainable_variables=None)\n    with self.assertRaisesRegexp(\n        ValueError, r'trainable_variables should be a list or a tuple'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array(((45,), (-41,),), dtype=np.float32),\n          labels=np.array(((1,), (1,),), dtype=np.float64),\n          optimizer=_Optimizer('my_optimizer'),\n          trainable_variables={\n              'var_list': [tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)]\n          })\n\n  def test_train_with_optimizer(self):\n    head = head_lib.BinaryClassHead()\n\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.float64)\n    expected_train_result = b'my_train_op'\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # loss = sum(cross_entropy(labels, logits)) / batch_size\n    #      = sum(0, 41) / 2 = 41 / 2 = 20.5\n    expected_loss = 20.5\n\n    class _Optimizer(tf_keras.optimizers.Optimizer):\n\n      def get_updates(self, loss, params):\n        del params\n        with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n            tf.cast(expected_loss, dtype=tf.dtypes.float32),\n         
   tf.cast(loss, dtype=tf.dtypes.float32),\n            name='assert_loss'),)):\n          return [tf.constant(expected_train_result)]\n\n      def get_config(self):\n        config = super(_Optimizer, self).get_config()\n        return config\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        optimizer=_Optimizer('my_optimizer'),\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      loss, train_result = sess.run((spec.loss, spec.train_op))\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n\n  def test_train_with_update_ops(self):\n    with tf.Graph().as_default():\n      w = tf.Variable(1)\n      update_op = w.assign_add(1)\n\n      t = tf.Variable('')\n      expected_train_result = b'my_train_op'\n\n      def _train_op_fn(loss):\n        del loss\n        return t.assign(expected_train_result)\n\n      head = head_lib.BinaryClassHead()\n      spec = head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array(((45,), (-41,),), dtype=np.float32),\n          labels=np.array(((1,), (1,),), dtype=np.float64),\n          train_op_fn=_train_op_fn,\n          update_ops=[update_op],\n          trainable_variables=[\n              tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n          ])\n\n      with self.cached_session() as sess:\n        test_lib._initialize_variables(self, spec.scaffold)\n        sess.run(spec.train_op)\n        w_value, t_value = sess.run([w, t])\n        self.assertEqual(2, w_value)\n        self.assertEqual(expected_train_result, t_value)\n\n  def test_train_summaries_with_head_name(self):\n    head = 
head_lib.BinaryClassHead(name='some_binary_head')\n\n    logits = np.array(((45,), (-41,),), dtype=np.float32)\n    labels = np.array(((1,), (1,),), dtype=np.float64)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # loss = sum(cross_entropy(labels, logits)) / batch_size\n    #      = sum(0, 41) / 2 = 20.5\n    expected_loss = 20.5\n\n    def _train_op_fn(loss):\n      del loss\n      return tf.no_op()\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    # Assert summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      summary_str = sess.run(spec.scaffold.summary_op)\n      test_lib._assert_simple_summaries(self, {\n          '{}/some_binary_head'.format(metric_keys.MetricKeys.LOSS):\n              expected_loss,\n      }, summary_str)\n\n  def test_lookup_tables_in_graph(self):\n    head = head_lib.BinaryClassHead(label_vocabulary=['aang', 'iroh'])\n\n    feature_columns = [tf.feature_column.numeric_column('x')]\n    est = dnn.DNNEstimatorV2(\n        head=head,\n        hidden_units=(2, 2),\n        feature_columns=feature_columns,\n        batch_norm=True)\n\n    def input_fn():\n      return ({\n          'x': np.array(((42,), (43,),), dtype=np.int32)\n      }, [[b'iroh'], [b'iroh']])\n\n    # Train.\n    num_steps = 1\n    est.train(input_fn, steps=num_steps)\n    # Eval.\n    eval_results = est.evaluate(input_fn, steps=num_steps)\n    self.assertEqual(num_steps,\n                     eval_results[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn(metric_keys.MetricKeys.LOSS_MEAN, six.iterkeys(eval_results))\n    # Predict.\n    est.predict(input_fn)\n\n\nif __name__ == '__main__':\n  
tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/head_utils.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Utilities for heads and unit tests.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.head import binary_class_head\nfrom tensorflow_estimator.python.estimator.head import multi_class_head\n\n_DEFAULT_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n\n\ndef binary_or_multi_class_head(n_classes, weight_column, label_vocabulary,\n                               loss_reduction):\n  \"\"\"Creates either binary or multi-class head.\n\n  Args:\n    n_classes: Number of label classes.\n    weight_column: A string or a `NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It is used to down weight or boost examples during training. It\n      will be multiplied by the loss of the example. If it is a string, it is\n      used as a key to fetch weight tensor from the `features`. If it is a\n      `NumericColumn`, raw tensor is fetched by key `weight_column.key`, then\n      weight_column.normalizer_fn is applied on it to get weight tensor.\n    label_vocabulary: A list of strings represents possible label values. 
If\n      given, labels must be string type and have any value in\n      `label_vocabulary`. If it is not given, that means labels are already\n      encoded as integer or float within [0, 1] for `n_classes=2` and encoded as\n      integer values in {0, 1,..., n_classes-1} for `n_classes`>2 . Also there\n      will be errors if vocabulary is not provided and labels are string.\n    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Defines how to\n      reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`.\n\n  Returns:\n    A `Head` instance.\n  \"\"\"\n  if n_classes == 2:\n    head = binary_class_head.BinaryClassHead(\n        weight_column=weight_column,\n        label_vocabulary=label_vocabulary,\n        loss_reduction=loss_reduction)\n  else:\n    head = multi_class_head.MultiClassHead(\n        n_classes,\n        weight_column=weight_column,\n        label_vocabulary=label_vocabulary,\n        loss_reduction=loss_reduction)\n  return head\n\n\ndef _initialize_variables(test_case, scaffold):\n  scaffold.finalize()\n  test_case.assertIsNone(scaffold.init_feed_dict)\n  test_case.assertIsNone(scaffold.init_fn)\n  scaffold.init_op.run()\n  scaffold.ready_for_local_init_op.eval()\n  scaffold.local_init_op.run()\n  scaffold.ready_op.eval()\n  test_case.assertIsNotNone(scaffold.saver)\n\n\ndef _assert_simple_summaries(test_case,\n                             expected_summaries,\n                             summary_str,\n                             tol=1e-6):\n  \"\"\"Assert summary the specified simple values.\n\n  Args:\n    test_case: test case.\n    expected_summaries: Dict of expected tags and simple values.\n    summary_str: Serialized `summary_pb2.Summary`.\n    tol: Tolerance for relative and absolute.\n  \"\"\"\n  summary = tf.compat.v1.summary.Summary()\n  summary.ParseFromString(summary_str)\n  test_case.assertAllClose(\n      expected_summaries, {v.tag: v.simple_value for v in summary.value},\n      rtol=tol,\n      
atol=tol)\n\n\ndef _assert_no_hooks(test_case, spec):\n  test_case.assertAllEqual([], spec.training_chief_hooks)\n  test_case.assertAllEqual([], spec.training_hooks)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/multi_class_head.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Multi class head.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.ops import lookup_ops\nfrom tensorflow_estimator.python.estimator import model_fn\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.export import export_output\nfrom tensorflow_estimator.python.estimator.head import base_head\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v2\n\n\n@estimator_export('estimator.MultiClassHead')\nclass MultiClassHead(base_head.Head):\n  \"\"\"Creates a `Head` for multi class classification.\n\n  Uses `sparse_softmax_cross_entropy` loss.\n\n  The head expects `logits` with shape `[D0, D1, ... 
DN, n_classes]`.\n  In many applications, the shape is `[batch_size, n_classes]`.\n\n  `labels` must be a dense `Tensor` with shape matching `logits`, namely\n  `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string\n  `Tensor` with values from the vocabulary. If `label_vocabulary` is not given,\n  `labels` must be an integer `Tensor` with values specifying the class index.\n\n  If `weight_column` is specified, weights must be of shape\n  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.\n\n  The loss is the weighted sum over the input dimensions. Namely, if the input\n  labels have shape `[batch_size, 1]`, the loss is the weighted sum over\n  `batch_size`.\n\n  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or\n  `(labels, logits, features, loss_reduction)` as arguments and returns\n  unreduced loss with shape `[D0, D1, ... DN, 1]`. `loss_fn` must support\n  integer `labels` with shape `[D0, D1, ... DN, 1]`. Namely, the head applies\n  `label_vocabulary` to the input labels before passing them to `loss_fn`.\n\n  Usage:\n\n  >>> n_classes = 3\n  >>> head = tf.estimator.MultiClassHead(n_classes)\n  >>> logits = np.array(((10, 0, 0), (0, 10, 0),), dtype=np.float32)\n  >>> labels = np.array(((1,), (1,)), dtype=np.int64)\n  >>> features = {'x': np.array(((42,),), dtype=np.int32)}\n  >>> # expected_loss = sum(cross_entropy(labels, logits)) / batch_size\n  >>> #               = sum(10, 0) / 2 = 5.\n  >>> loss = head.loss(labels, logits, features=features)\n  >>> print('{:.2f}'.format(loss.numpy()))\n  5.00\n  >>> eval_metrics = head.metrics()\n  >>> updated_metrics = head.update_metrics(\n  ...   eval_metrics, features, logits, labels)\n  >>> for k in sorted(updated_metrics):\n  ...   print('{} : {:.2f}'.format(k, updated_metrics[k].result().numpy()))\n  accuracy : 0.50\n  average_loss : 5.00\n  >>> preds = head.predictions(logits)\n  >>> print(preds['logits'])\n  tf.Tensor(\n    [[10.  0.  0.]\n     [ 0. 10.  
0.]], shape=(2, 3), dtype=float32)\n\n  Usage with a canned estimator:\n\n  ```python\n  my_head = tf.estimator.MultiClassHead(n_classes=3)\n  my_estimator = tf.estimator.DNNEstimator(\n      head=my_head,\n      hidden_units=...,\n      feature_columns=...)\n  ```\n\n  It can also be used with a custom `model_fn`. Example:\n\n  ```python\n  def _my_model_fn(features, labels, mode):\n    my_head = tf.estimator.MultiClassHead(n_classes=3)\n    logits = tf_keras.Model(...)(features)\n\n    return my_head.create_estimator_spec(\n        features=features,\n        mode=mode,\n        labels=labels,\n        optimizer=tf_keras.optimizers.Adagrad(lr=0.1),\n        logits=logits)\n\n  my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)\n  ```\n\n  Args:\n    n_classes: Number of classes, must be greater than 2 (for 2 classes, use\n      `BinaryClassHead`).\n    weight_column: A string or a `NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It is used to down weight or boost examples during training. It\n      will be multiplied by the loss of the example.\n    label_vocabulary: A list or tuple of strings representing possible label\n      values. If it is not given, that means labels are already encoded as an\n      integer within [0, n_classes). If given, labels must be of string type and\n      have any value in `label_vocabulary`. Note that errors will be raised if\n      `label_vocabulary` is not provided but labels are strings. If both\n      `n_classes` and `label_vocabulary` are provided, `label_vocabulary` should\n      contain exactly `n_classes` items.\n    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Decides how to\n      reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely\n      weighted sum of losses divided by `batch size * label_dimension`.\n    loss_fn: Optional loss function.\n    name: Name of the head. 
If provided, summary and metrics keys will be\n      suffixed by `\"/\" + name`. Also used as `name_scope` when creating ops.\n  \"\"\"\n\n  def __init__(self,\n               n_classes,\n               weight_column=None,\n               label_vocabulary=None,\n               loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,\n               loss_fn=None,\n               name=None):\n    if n_classes is None:\n      raise ValueError('n_classes cannot be None')\n    if label_vocabulary is not None and not isinstance(label_vocabulary,\n                                                       (list, tuple)):\n      raise ValueError(\n          'label_vocabulary should be a list or a tuple. Given type: {}'.format(\n              type(label_vocabulary)))\n    if label_vocabulary is not None and len(label_vocabulary) != n_classes:\n      raise ValueError(\n          '\"label_vocabulary\" does not have \"n_classes\" items. '\n          'len(label_vocabulary)={}, n_classes={}, label_vocabulary={}'.format(\n              len(label_vocabulary), n_classes, label_vocabulary))\n    base_head.validate_loss_reduction(loss_reduction)\n    if loss_fn:\n      base_head.validate_loss_fn_args(loss_fn)\n    self._n_classes = base_head.validate_n_classes(n_classes)\n    self._weight_column = weight_column\n    self._label_vocabulary = label_vocabulary\n    self._loss_reduction = loss_reduction\n    self._loss_fn = loss_fn\n    self._name = name\n    # Metric keys.\n    keys = metric_keys.MetricKeys\n    self._loss_mean_key = self._summary_key(keys.LOSS_MEAN)\n    self._accuracy_key = self._summary_key(keys.ACCURACY)\n    self._loss_regularization_key = self._summary_key(keys.LOSS_REGULARIZATION)\n\n  @property\n  def name(self):\n    \"\"\"See `base_head.Head` for details.\"\"\"\n    return self._name\n\n  @property\n  def logits_dimension(self):\n    \"\"\"See `base_head.Head` for details.\"\"\"\n    return self._n_classes\n\n  @property\n  def loss_reduction(self):\n    \"\"\"See 
`base_head.Head` for details.\"\"\"\n    return self._loss_reduction\n\n  # Attributes for lookup tables in Eager execution. Note that for Graph\n  # execution, the lookup tables are created on demanded to make sure the\n  # lookup table is in the same graph as its input tensors for `train` and\n  # 'eval' of Estimator (as Estimator recreates graphs for `train`, `eval` and\n  # `predict`).\n  _cached_class_id_table = None\n  _cached_class_string_table = None\n\n  @property\n  def _class_id_table(self):\n    \"\"\"Creates a lookup table for class_id.\n\n    In eager execution, this lookup table will be lazily created on the first\n    call of `self._class_id_table`, and cached for later use; In graph\n    execution, it will be created on demand.\n\n    Returns:\n      A hash table for lookup.\n    \"\"\"\n    if self._cached_class_id_table is None or not tf.executing_eagerly():\n      self._cached_class_id_table = lookup_ops.index_table_from_tensor(\n          vocabulary_list=tuple(self._label_vocabulary), name='class_id_lookup')\n    return self._cached_class_id_table\n\n  @property\n  def _class_string_table(self):\n    \"\"\"Creates a lookup table for class_string.\n\n    In eager execution, this lookup table will be lazily created on the first\n    call of `self._class_string_table` and cached for later use; In graph\n    execution, it will be created on demand.\n\n    Returns:\n      A hash table for lookup.\n    \"\"\"\n    if (self._cached_class_string_table is None or not tf.executing_eagerly()):\n      self._cached_class_string_table = (\n          lookup_ops.index_to_string_table_from_tensor(\n              vocabulary_list=self._label_vocabulary,\n              name='class_string_lookup'))\n    return self._cached_class_string_table\n\n  def _processed_labels(self, logits, labels):\n    \"\"\"Converts labels to integer id space.\"\"\"\n    labels = base_head.check_dense_labels_match_logits_and_reshape(\n        labels=labels, logits=logits, 
expected_labels_dimension=1)\n    if self._label_vocabulary is None:\n      if not labels.dtype.is_integer:\n        raise ValueError(\n            'Labels dtype should be integer. Instead got {}.'.format(\n                labels.dtype))\n      label_ids = labels\n    else:\n      if labels.dtype != tf.dtypes.string:\n        raise ValueError('Labels dtype should be string if there is a '\n                         'vocabulary. Instead got {}'.format(labels.dtype))\n      label_ids = self._class_id_table.lookup(labels)\n    return base_head.check_label_range(label_ids, self._n_classes)\n\n  def _unweighted_loss_and_weights(self, logits, label_ids, features):\n    \"\"\"Computes loss spec.\"\"\"\n    if self._loss_fn:\n      unweighted_loss = base_head.call_loss_fn(\n          loss_fn=self._loss_fn,\n          labels=label_ids,\n          logits=logits,\n          features=features,\n          expected_loss_dim=1)\n    else:\n      unweighted_loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(\n          labels=label_ids,\n          logits=logits,\n          reduction=tf.compat.v1.losses.Reduction.NONE)\n      # Restore the squeezed dim, so unweighted_loss matches the weights shape.\n      unweighted_loss = tf.compat.v1.expand_dims(unweighted_loss, axis=-1)\n    weights = base_head.get_weights_and_check_match_logits(\n        features=features, weight_column=self._weight_column, logits=logits)\n    return unweighted_loss, weights\n\n  def loss(self,\n           labels,\n           logits,\n           features=None,\n           mode=None,\n           regularization_losses=None):\n    \"\"\"Returns regularized training loss. 
See `base_head.Head` for details.\"\"\"\n    del mode  # Unused for this head.\n    with ops.name_scope(\n        'losses', values=(logits, labels, regularization_losses, features)):\n      logits = base_head.check_logits_final_dim(logits, self.logits_dimension)\n      label_ids = self._processed_labels(logits, labels)\n      unweighted_loss, weights = self._unweighted_loss_and_weights(\n          logits, label_ids, features)\n      training_loss = tf_keras_v2.__internal__.losses.compute_weighted_loss(\n          unweighted_loss,\n          sample_weight=weights,\n          reduction=self._loss_reduction)\n      regularization_loss = tf.math.add_n(\n          regularization_losses) if regularization_losses is not None else None\n      regularized_training_loss = (\n          training_loss + regularization_loss\n          if regularization_loss is not None else training_loss)\n    return regularized_training_loss\n\n  def predictions(self, logits, keys=None):\n    \"\"\"Return predictions based on keys.\n\n    See `base_head.Head` for details.\n\n    Args:\n      logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.\n        For many applications, the shape is `[batch_size, logits_dimension]`.\n      keys: a list or tuple of prediction keys. Each key can be either the class\n        variable of prediction_keys.PredictionKeys or its string value, such as:\n          prediction_keys.PredictionKeys.CLASSES or 'classes'. 
If not specified,\n          it will return the predictions for all valid keys.\n\n    Returns:\n      A dict of predictions.\n    \"\"\"\n    pred_keys = prediction_keys.PredictionKeys\n    valid_keys = [\n        pred_keys.LOGITS, pred_keys.PROBABILITIES, pred_keys.CLASS_IDS,\n        pred_keys.CLASSES, pred_keys.ALL_CLASS_IDS, pred_keys.ALL_CLASSES\n    ]\n    if keys:\n      base_head.check_prediction_keys(keys, valid_keys)\n    else:\n      keys = valid_keys\n    logits = base_head.check_logits_final_dim(logits, self.logits_dimension)\n    predictions = {}\n    with ops.name_scope('predictions', values=(logits,)):\n      if pred_keys.LOGITS in keys:\n        predictions[pred_keys.LOGITS] = logits\n      if pred_keys.PROBABILITIES in keys:\n        probabilities = tf.compat.v1.nn.softmax(\n            logits, name=pred_keys.PROBABILITIES)\n        predictions[pred_keys.PROBABILITIES] = probabilities\n      if pred_keys.CLASS_IDS in keys or pred_keys.CLASSES in keys:\n        # class_ids's shape is [D0, D1, ... 
DN].\n        class_ids = tf.compat.v1.math.argmax(\n            logits, axis=-1, name=pred_keys.CLASS_IDS)\n        # Expand to [batch_size, 1].\n        class_ids = tf.compat.v1.expand_dims(class_ids, axis=-1)\n        if pred_keys.CLASS_IDS in keys:\n          predictions[pred_keys.CLASS_IDS] = class_ids\n        if pred_keys.CLASSES in keys:\n          if self._label_vocabulary:\n            classes = self._class_string_table.lookup(class_ids)\n          else:\n            classes = tf.strings.as_string(class_ids, name='str_classes')\n          predictions[pred_keys.CLASSES] = classes\n      if pred_keys.ALL_CLASS_IDS in keys:\n        predictions[pred_keys.ALL_CLASS_IDS] = base_head.all_class_ids(\n            logits, n_classes=self._n_classes)\n      if pred_keys.ALL_CLASSES in keys:\n        predictions[pred_keys.ALL_CLASSES] = base_head.all_classes(\n            logits,\n            n_classes=self._n_classes,\n            label_vocabulary=self._label_vocabulary)\n      return predictions\n\n  def metrics(self, regularization_losses=None):\n    \"\"\"Creates metrics. See `base_head.Head` for details.\"\"\"\n    keys = metric_keys.MetricKeys\n    with ops.name_scope('metrics', values=(regularization_losses,)):\n      # Mean metric.\n      eval_metrics = {}\n      eval_metrics[self._loss_mean_key] = tf_keras.metrics.Mean(\n          name=keys.LOSS_MEAN)\n      if regularization_losses is not None:\n        eval_metrics[self._loss_regularization_key] = tf_keras.metrics.Mean(\n            name=keys.LOSS_REGULARIZATION)\n      # Accuracy metric.\n      eval_metrics[self._accuracy_key] = tf_keras.metrics.Accuracy(\n          name=keys.ACCURACY)\n    return eval_metrics\n\n  def update_metrics(self,\n                     eval_metrics,\n                     features,\n                     logits,\n                     labels,\n                     regularization_losses=None):\n    \"\"\"Updates eval metrics. 
See `base_head.Head` for details.\"\"\"\n    preds = self.predictions(logits)\n    class_ids = preds[prediction_keys.PredictionKeys.CLASS_IDS]\n    logits = base_head.check_logits_final_dim(logits, self.logits_dimension)\n    label_ids = self._processed_labels(logits, labels)\n    unweighted_loss, weights = self._unweighted_loss_and_weights(\n        logits, label_ids, features)\n\n    # Update metrics.\n    eval_metrics[self._loss_mean_key].update_state(\n        values=unweighted_loss, sample_weight=weights)\n    eval_metrics[self._accuracy_key].update_state(\n        y_true=label_ids, y_pred=class_ids, sample_weight=weights)\n\n    if regularization_losses is not None:\n      regularization_loss = tf.math.add_n(regularization_losses)\n      eval_metrics[self._loss_regularization_key].update_state(\n          values=regularization_loss)\n    return eval_metrics\n\n  def _create_tpu_estimator_spec(self,\n                                 features,\n                                 mode,\n                                 logits,\n                                 labels=None,\n                                 optimizer=None,\n                                 trainable_variables=None,\n                                 train_op_fn=None,\n                                 update_ops=None,\n                                 regularization_losses=None):\n    \"\"\"Returns a `model_fn._TPUEstimatorSpec`.\n\n    Args:\n      features: Input `dict` of `Tensor` or `SparseTensor` objects.\n      mode: Estimator's `ModeKeys`.\n      logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.\n        For many applications, the shape is `[batch_size, logits_dimension]`.\n      labels: Labels integer or string `Tensor` with shape matching `logits`,\n        namely `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN]`. 
`labels` is required\n        argument when `mode` equals `TRAIN` or `EVAL`.\n      optimizer: An `tf_keras.optimizers.Optimizer` instance to optimize the\n        loss in TRAIN mode. Namely, sets `train_op = optimizer.get_updates(loss,\n        trainable_variables)`, which updates variables to minimize `loss`.\n      trainable_variables: A list or tuple of `Variable` objects to update to\n        minimize `loss`. In Tensorflow 1.x, by default these are the list of\n        variables collected in the graph under the key\n        `GraphKeys.TRAINABLE_VARIABLES`. As Tensorflow 2.x doesn't have\n        collections and GraphKeys, trainable_variables need to be passed\n        explicitly here.\n      train_op_fn: Function that takes a scalar loss `Tensor` and returns\n        `train_op`. Used if `optimizer` is `None`.\n      update_ops: A list or tuple of update ops to be run at training time. For\n        example, layers such as BatchNormalization create mean and variance\n        update ops that need to be run at training time. In Tensorflow 1.x,\n        these are thrown into an UPDATE_OPS collection. As Tensorflow 2.x\n        doesn't have collections, update_ops need to be passed explicitly here.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses. 
These losses are\n        usually expressed as a batch average, so for best results users need to\n        use the default `loss_reduction=SUM_OVER_BATCH_SIZE` when creating the\n        head to avoid scaling errors.\n\n    Returns:\n      A `model_fn._TPUEstimatorSpec` instance.\n\n    Raises:\n      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN\n        mode, or if both are set.\n    \"\"\"\n    with ops.name_scope(self._name, 'head'):\n      # Predict.\n      pred_keys = prediction_keys.PredictionKeys\n      predictions = self.predictions(logits)\n      if mode == ModeKeys.PREDICT:\n        probabilities = predictions[pred_keys.PROBABILITIES]\n        classifier_output = base_head.classification_output(\n            scores=probabilities,\n            n_classes=self._n_classes,\n            label_vocabulary=self._label_vocabulary)\n        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n            mode=ModeKeys.PREDICT,\n            predictions=predictions,\n            export_outputs={\n                base_head.DEFAULT_SERVING_KEY:\n                    classifier_output,\n                base_head.CLASSIFY_SERVING_KEY:\n                    classifier_output,\n                base_head.PREDICT_SERVING_KEY:\n                    export_output.PredictOutput(predictions)\n            })\n      regularized_training_loss = self.loss(\n          logits=logits,\n          labels=labels,\n          features=features,\n          mode=mode,\n          regularization_losses=regularization_losses)\n      # Eval.\n      if mode == ModeKeys.EVAL:\n        eval_metrics = self.metrics(regularization_losses=regularization_losses)\n        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n            mode=ModeKeys.EVAL,\n            predictions=predictions,\n            loss=regularized_training_loss,\n            eval_metrics=base_head.create_eval_metrics_tuple(\n                self.update_metrics, {\n        
            'eval_metrics': eval_metrics,\n                    'features': features,\n                    'logits': logits,\n                    'labels': labels,\n                    'regularization_losses': regularization_losses\n                }))\n      # Train.\n      train_op = base_head.create_estimator_spec_train_op(\n          head_name=self._name,\n          optimizer=optimizer,\n          train_op_fn=train_op_fn,\n          update_ops=update_ops,\n          trainable_variables=trainable_variables,\n          regularized_training_loss=regularized_training_loss,\n          loss_reduction=self._loss_reduction)\n    # Create summary.\n    base_head.create_estimator_spec_summary(\n        regularized_training_loss=regularized_training_loss,\n        regularization_losses=regularization_losses,\n        summary_key_fn=self._summary_key)\n    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n        mode=ModeKeys.TRAIN,\n        predictions=predictions,\n        loss=regularized_training_loss,\n        train_op=train_op)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/multi_class_head_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for multi_class_head.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import dnn\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.head import head_utils as test_lib\nfrom tensorflow_estimator.python.estimator.head import multi_class_head as head_lib\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass MultiClassHead(tf.test.TestCase):\n\n  def test_n_classes_is_none(self):\n    with self.assertRaisesRegexp(ValueError, 'n_classes cannot be None'):\n      head_lib.MultiClassHead(n_classes=None)\n\n  def test_n_classes_is_2(self):\n    with self.assertRaisesRegexp(ValueError, 'n_classes must be > 2'):\n      head_lib.MultiClassHead(n_classes=2)\n\n  def test_invalid_loss_reduction(self):\n    with self.assertRaisesRegexp(\n        ValueError, 
r'Invalid loss_reduction: invalid_loss_reduction'):\n      head_lib.MultiClassHead(\n          n_classes=3, loss_reduction='invalid_loss_reduction')\n    with self.assertRaisesRegexp(ValueError, r'Invalid loss_reduction: none'):\n      head_lib.MultiClassHead(\n          n_classes=3, loss_reduction=tf.losses.Reduction.NONE)\n\n  def test_loss_fn_arg_labels_missing(self):\n\n    def _loss_fn(logits):\n      del logits  # Unused\n\n    with self.assertRaisesRegexp(\n        ValueError, r'loss_fn must contain argument: labels\\. '\n        r'Given arguments: \\(\\'logits\\',\\)'):\n      head_lib.MultiClassHead(n_classes=3, loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_logits_missing(self):\n\n    def _loss_fn(labels):\n      del labels  # unused\n\n    with self.assertRaisesRegexp(\n        ValueError, r'loss_fn must contain argument: logits\\. '\n        r'Given arguments: \\(\\'labels\\',\\)'):\n      head_lib.MultiClassHead(n_classes=3, loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_features_ok(self):\n\n    def _loss_fn(labels, logits, features):\n      del labels, logits, features  # Unused\n\n    head_lib.MultiClassHead(n_classes=3, loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_invalid(self):\n\n    def _loss_fn(labels, logits, name=None):\n      del labels, logits, name  # Unused\n\n    with self.assertRaisesRegexp(ValueError,\n                                 r'loss_fn has unexpected args: \\[\\'name\\'\\]'):\n      head_lib.MultiClassHead(n_classes=3, loss_fn=_loss_fn)\n\n  def test_invalid_logits_shape(self):\n    n_classes = 3\n    head = head_lib.MultiClassHead(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    # Logits should be shape (batch_size, 3).\n    logits_2x2 = np.array((\n        (45., 44.),\n        (41., 42.),\n    ))\n    pred_key = prediction_keys.PredictionKeys.PROBABILITIES\n    # Static shape.\n    with self.assertRaisesRegexp(ValueError, 'logits shape'):\n      preds = head.predictions(logits_2x2, [pred_key])\n      
self.evaluate(preds[pred_key])\n    if tf.executing_eagerly():\n      return\n\n    # Dynamic shape only works in Graph mode.\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    spec = head.create_estimator_spec(\n        features={'x': np.array((\n            (30.,),\n            (42.,),\n        ))},\n        mode=ModeKeys.PREDICT,\n        logits=logits_placeholder,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    with self.cached_session():\n      with self.assertRaisesRegexp(tf.errors.OpError, 'logits shape'):\n        spec.predictions[pred_key].eval({logits_placeholder: logits_2x2})\n\n  def test_invalid_labels_shape(self):\n    n_classes = 3\n    head = head_lib.MultiClassHead(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    # Logits should be shape (batch_size, 3).\n    # Labels should be shape (batch_size, 1).\n    labels_2x2 = np.array((\n        (1, 2),\n        (0, 1),\n    ), dtype=int)\n    logits_2x3 = np.array((\n        (1., 2., 3.),\n        (1., 2., 3.),\n    ))\n    features = {'x': np.array(((42.,),))}\n\n    # Static shape.\n    with self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):\n      training_loss = head.loss(\n          logits=logits_2x3,\n          labels=labels_2x2,\n          features=features,\n          mode=ModeKeys.EVAL)\n      self.evaluate(training_loss)\n    if tf.executing_eagerly():\n      return\n\n    # Dynamic shape only works in Graph mode.\n    labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.int64)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    training_loss = head.loss(\n        logits=logits_placeholder,\n        labels=labels_placeholder,\n        features=features,\n        mode=ModeKeys.EVAL)\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[expected_labels_shape: \\] \\[2 1\\] 
\\[labels_shape: \\] \\[2 2\\]'):\n        training_loss.eval({\n            logits_placeholder: logits_2x3,\n            labels_placeholder: labels_2x2\n        })\n\n  def test_invalid_labels_type(self):\n    n_classes = 3\n    head = head_lib.MultiClassHead(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    # Logits should be shape (batch_size, 3).\n    # Labels should be shape (batch_size, 1).\n    labels_2x1 = np.array((\n        (1.,),\n        (1.,),\n    ))\n    logits_2x3 = np.array((\n        (1., 2., 3.),\n        (1., 2., 3.),\n    ))\n    features = {'x': np.array(((42.,),))}\n\n    # Static shape.\n    with self.assertRaisesRegexp(ValueError, 'Labels dtype'):\n      head.loss(\n          logits=logits_2x3,\n          labels=labels_2x1,\n          features=features,\n          mode=ModeKeys.EVAL)\n    if tf.executing_eagerly():\n      return\n\n    # Dynamic shape only works in Graph mode.\n    labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    with self.assertRaisesRegexp(ValueError, 'Labels dtype'):\n      head.loss(\n          logits=logits_placeholder,\n          labels=labels_placeholder,\n          features=features,\n          mode=ModeKeys.EVAL)\n\n  def test_invalid_labels_values(self):\n    n_classes = 3\n    head = head_lib.MultiClassHead(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    labels_2x1_with_large_id = np.array((\n        (45,),\n        (1,),\n    ), dtype=int)\n    labels_2x1_with_negative_id = np.array((\n        (-5,),\n        (1,),\n    ), dtype=int)\n    logits_2x3 = np.array((\n        (1., 2., 4.),\n        (1., 2., 3.),\n    ))\n    features = {'x': np.array(((42.,),))}\n\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, 'Labels must be <= 3 - 1'):\n        training_loss = head.loss(\n            logits=logits_2x3,\n            
labels=labels_2x1_with_large_id,\n            features=features,\n            mode=ModeKeys.EVAL)\n\n      with self.assertRaisesRegexp(ValueError, 'Labels must be >= 0'):\n        training_loss = head.loss(\n            logits=logits_2x3,\n            labels=labels_2x1_with_negative_id,\n            features=features,\n            mode=ModeKeys.EVAL)\n      return\n\n    # Dynamic shape only works in Graph mode.\n    labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.int64)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    training_loss = head.loss(\n        logits=logits_placeholder,\n        labels=labels_placeholder,\n        features=features,\n        mode=ModeKeys.EVAL)\n    with self.cached_session():\n      with self.assertRaisesOpError('Labels must be <= n_classes - 1'):\n        training_loss.eval({\n            labels_placeholder: labels_2x1_with_large_id,\n            logits_placeholder: logits_2x3\n        })\n\n    with self.cached_session():\n      with self.assertRaisesOpError('Labels must be >= 0'):\n        training_loss.eval({\n            labels_placeholder: labels_2x1_with_negative_id,\n            logits_placeholder: logits_2x3\n        })\n\n  def test_invalid_labels_sparse_tensor(self):\n    n_classes = 3\n    head = head_lib.MultiClassHead(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    labels_2x1 = tf.sparse.SparseTensor(\n        values=['english', 'italian'],\n        indices=[[0, 0], [1, 0]],\n        dense_shape=[2, 1])\n    logits_2x3 = np.array((\n        (1., 2., 4.),\n        (1., 2., 3.),\n    ))\n\n    with self.assertRaisesRegexp(ValueError,\n                                 'SparseTensor labels are not supported.'):\n      loss = head.loss(\n          logits=logits_2x3,\n          labels=labels_2x1,\n          features={'x': np.array(((42.,),))},\n          mode=ModeKeys.EVAL)\n      self.evaluate(loss)\n\n  def test_incompatible_labels_shape(self):\n    
n_classes = 3\n    head = head_lib.MultiClassHead(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    # Logits should be shape (batch_size, 3).\n    # Labels should be shape (batch_size, 1).\n    # Here batch sizes are different.\n    values_3x1 = np.array((\n        (1,),\n        (1,),\n        (1,),\n    ))\n    values_2x3 = np.array((\n        (1., 2., 3.),\n        (1., 2., 3.),\n    ))\n    features = {'x': values_2x3}\n\n    # Static shape.\n    # Eager mode.\n    if tf.executing_eagerly():\n      with self.assertRaisesRegex(ValueError, 'labels shape'):\n        head.loss(\n            logits=values_2x3,\n            labels=values_3x1,\n            features=features,\n            mode=ModeKeys.EVAL)\n      return\n    # Graph mode.\n    with self.assertRaisesRegex(ValueError, r'shape.*\\(3,\\).*\\(2, 3\\)'):\n      head.loss(\n          logits=values_2x3,\n          labels=values_3x1,\n          features=features,\n          mode=ModeKeys.EVAL)\n\n    # Dynamic shape only works in Graph mode.\n    labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.int64)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    training_loss = head.loss(\n        logits=logits_placeholder,\n        labels=labels_placeholder,\n        features=features,\n        mode=ModeKeys.EVAL)\n    with self.cached_session():\n      with self.assertRaisesRegex(\n          tf.errors.InvalidArgumentError,\n          r'\\[expected_labels_shape: \\] \\[2 1\\] \\[labels_shape: \\] \\[3 1\\]'):\n        training_loss.eval({\n            labels_placeholder: values_3x1,\n            logits_placeholder: values_2x3\n        })\n\n  def test_predict(self):\n    n_classes = 3\n    head = head_lib.MultiClassHead(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    logits = [[1., 0., 0.], [0., 0., 1.]]\n    expected_probabilities = [[0.576117, 0.2119416, 0.2119416],\n                              [0.2119416, 0.2119416, 
0.576117]]\n    expected_class_ids = [[0], [2]]\n    expected_all_class_ids = [[0, 1, 2]] * 2\n    expected_classes = [[b'0'], [b'2']]\n    expected_all_classes = [[b'0', b'1', b'2']] * 2\n    expected_export_classes = [[b'0', b'1', b'2']] * 2\n\n    keys = prediction_keys.PredictionKeys\n    preds = head.predictions(logits)\n    self.assertAllClose(logits, self.evaluate(preds[keys.LOGITS]))\n    self.assertAllClose(expected_probabilities,\n                        self.evaluate(preds[keys.PROBABILITIES]))\n    self.assertAllClose(expected_class_ids,\n                        self.evaluate(preds[keys.CLASS_IDS]))\n    self.assertAllEqual(expected_classes, self.evaluate(preds[keys.CLASSES]))\n    self.assertAllClose(expected_all_class_ids,\n                        self.evaluate(preds[keys.ALL_CLASS_IDS]))\n    self.assertAllEqual(expected_all_classes,\n                        self.evaluate(preds[keys.ALL_CLASSES]))\n    if tf.executing_eagerly():\n      return\n\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    self.assertItemsEqual(\n        (test_lib._DEFAULT_SERVING_KEY, 'predict', 'classification'),\n        spec.export_outputs.keys())\n\n    # Assert predictions and export_outputs.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      predictions = sess.run(spec.predictions)\n      self.assertAllClose(logits, predictions[keys.LOGITS])\n      self.assertAllClose(expected_probabilities,\n                          predictions[keys.PROBABILITIES])\n      self.assertAllClose(expected_class_ids, predictions[keys.CLASS_IDS])\n      self.assertAllEqual(expected_classes, predictions[keys.CLASSES])\n      self.assertAllClose(expected_all_class_ids,\n                          
predictions[keys.ALL_CLASS_IDS])\n      self.assertAllEqual(expected_all_classes, predictions[keys.ALL_CLASSES])\n\n      self.assertAllClose(\n          expected_probabilities,\n          sess.run(spec.export_outputs[test_lib._DEFAULT_SERVING_KEY].scores))\n      self.assertAllEqual(\n          expected_export_classes,\n          sess.run(spec.export_outputs[test_lib._DEFAULT_SERVING_KEY].classes))\n\n  def test_predict_with_tensor_n_classes(self):\n    n_classes = tf.constant(3, dtype=tf.dtypes.int32)\n    head = head_lib.MultiClassHead(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    logits = [[1., 0., 0.], [0., 0., 1.]]\n    expected_probabilities = [[0.576117, 0.2119416, 0.2119416],\n                              [0.2119416, 0.2119416, 0.576117]]\n    expected_class_ids = [[0], [2]]\n    expected_all_class_ids = [[0, 1, 2]] * 2\n    expected_classes = [[b'0'], [b'2']]\n    expected_all_classes = [[b'0', b'1', b'2']] * 2\n    expected_export_classes = [[b'0', b'1', b'2']] * 2\n\n    keys = prediction_keys.PredictionKeys\n    preds = head.predictions(logits)\n    self.assertAllClose(logits, self.evaluate(preds[keys.LOGITS]))\n    self.assertAllClose(expected_probabilities,\n                        self.evaluate(preds[keys.PROBABILITIES]))\n    self.assertAllClose(expected_class_ids,\n                        self.evaluate(preds[keys.CLASS_IDS]))\n    self.assertAllEqual(expected_classes, self.evaluate(preds[keys.CLASSES]))\n    self.assertAllClose(expected_all_class_ids,\n                        self.evaluate(preds[keys.ALL_CLASS_IDS]))\n    self.assertAllEqual(expected_all_classes,\n                        self.evaluate(preds[keys.ALL_CLASSES]))\n    if tf.executing_eagerly():\n      return\n\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    
self.assertItemsEqual(\n        (test_lib._DEFAULT_SERVING_KEY, 'predict', 'classification'),\n        spec.export_outputs.keys())\n\n    # Assert predictions and export_outputs.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      predictions = sess.run(spec.predictions)\n      self.assertAllClose(logits, predictions[keys.LOGITS])\n      self.assertAllClose(expected_probabilities,\n                          predictions[keys.PROBABILITIES])\n      self.assertAllClose(expected_class_ids, predictions[keys.CLASS_IDS])\n      self.assertAllEqual(expected_classes, predictions[keys.CLASSES])\n      self.assertAllClose(expected_all_class_ids,\n                          predictions[keys.ALL_CLASS_IDS])\n      self.assertAllEqual(expected_all_classes, predictions[keys.ALL_CLASSES])\n\n      self.assertAllClose(\n          expected_probabilities,\n          sess.run(spec.export_outputs[test_lib._DEFAULT_SERVING_KEY].scores))\n      self.assertAllEqual(\n          expected_export_classes,\n          sess.run(spec.export_outputs[test_lib._DEFAULT_SERVING_KEY].classes))\n\n  def test_predict_with_invalid_keys(self):\n    n_classes = 3\n    head = head_lib.MultiClassHead(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    logits = [[1., 0., 0.], [0., 0., 1.]]\n    with self.assertRaisesRegexp(\n        ValueError,\n        r'Prediction key must be in PredictionKeys, given: some_invalid_key'):\n      preds = head.predictions(logits, ['some_invalid_key'])\n      self.evaluate(preds)\n\n  def test_predict_with_vocabulary_list(self):\n    n_classes = 3\n    head = head_lib.MultiClassHead(\n        n_classes, label_vocabulary=['aang', 'iroh', 'zuko'])\n\n    logits = [[1., 0., 0.], [0., 0., 1.]]\n    expected_classes = [[b'aang'], [b'zuko']]\n    expected_export_classes = [[b'aang', b'iroh', b'zuko']] * 2\n    pred_key = 
prediction_keys.PredictionKeys.CLASSES\n    if tf.executing_eagerly():\n      preds = head.predictions(logits, [pred_key])\n      self.assertAllEqual(expected_classes,\n                          preds[prediction_keys.PredictionKeys.CLASSES])\n      return\n\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertAllEqual(expected_classes,\n                          sess.run(spec.predictions[pred_key]))\n      self.assertAllEqual(\n          expected_export_classes,\n          sess.run(spec.export_outputs[test_lib._DEFAULT_SERVING_KEY].classes))\n\n  def test_weight_should_not_impact_prediction(self):\n    n_classes = 3\n    head = head_lib.MultiClassHead(n_classes, weight_column='label_weights')\n    logits = [[1., 0., 0.], [0., 0., 1.]]\n    expected_probabilities = [[0.576117, 0.2119416, 0.2119416],\n                              [0.2119416, 0.2119416, 0.576117]]\n    weights_2x1 = [[1.], [2.]]\n    features = {\n        'x': np.array(((42,),), dtype=np.int32),\n        'label_weights': weights_2x1,\n    }\n\n    keys = prediction_keys.PredictionKeys\n    preds = head.predictions(logits, [keys.LOGITS, keys.PROBABILITIES])\n    self.assertAllClose(logits, self.evaluate(preds[keys.LOGITS]))\n    self.assertAllClose(expected_probabilities,\n                        self.evaluate(preds[keys.PROBABILITIES]))\n    if tf.executing_eagerly():\n      return\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.PREDICT,\n        logits=logits,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      
predictions = sess.run(spec.predictions)\n      self.assertAllClose(logits, predictions[keys.LOGITS])\n      self.assertAllClose(expected_probabilities,\n                          predictions[keys.PROBABILITIES])\n\n  def test_eval_create_loss(self):\n    n_classes = 3\n    head = head_lib.MultiClassHead(n_classes)\n\n    # logits: [2, 3], labels: [2, 1]\n    logits = np.array((\n        (10, 0, 0),\n        (0, 10, 0),\n    ), dtype=np.float32)\n    labels = np.array(((1,), (1,)), dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # loss = sum(cross_entropy(labels, logits)) / batch_size = 10 / 2 = 5.\n    expected_training_loss = 5.\n    # Create loss.\n    training_loss = head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)\n    self.assertAllClose(\n        expected_training_loss,\n        self.evaluate(training_loss),\n        rtol=1e-2,\n        atol=1e-2)\n\n  def test_eval_create_loss_loss_fn(self):\n    \"\"\"Tests head.loss for eval mode and custom loss_fn.\"\"\"\n    loss = np.array([[1.], [2.]], dtype=np.float32)\n    logits_input = np.array([[-10., 10., 0.], [-15., 10., 0]], dtype=np.float32)\n    labels_input = np.array([[1], [2]], dtype=np.int64)\n\n    def _loss_fn(labels, logits):\n      check_labels = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(labels, labels_input)), data=[labels])\n      check_logits = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(logits, logits_input)), data=[logits])\n      with tf.control_dependencies([check_labels, check_logits]):\n        return tf.constant(loss)\n\n    head = head_lib.MultiClassHead(n_classes=3, loss_fn=_loss_fn)\n\n    actual_training_loss = head.loss(\n        logits=logits_input,\n        labels=labels_input,\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.EVAL)\n    self.assertAllClose(np.sum(loss) / 2., self.evaluate(actual_training_loss))\n\n  def 
test_eval_create_loss_loss_fn_wrong_shape(self):\n    \"\"\"Tests custom loss_fn that returns Tensor of unexpected shape.\"\"\"\n    loss = np.array([1., 2.], dtype=np.float32)\n\n    def _loss_fn(labels, logits):\n      del labels, logits  # Unused\n      return tf.constant(loss)\n\n    head = head_lib.MultiClassHead(n_classes=3, loss_fn=_loss_fn)\n\n    logits = np.array([[-10., 10., 0.], [-15., 10., 0.]], dtype=np.float32)\n    labels = np.array([[1], [2]], dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, 'loss_shape'):\n        head.loss(logits=logits, labels=labels, features=features)\n    else:\n      actual_training_loss = head.loss(\n          logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[loss_fn must return Tensor of shape \\[D0, D1, ... DN, 1\\]\\. \\] '\n          r'\\[logits_shape: \\] \\[2 3\\] \\[loss_shape: \\] \\[2\\]'):\n        self.evaluate(actual_training_loss)\n\n  def test_eval_labels_none(self):\n    \"\"\"Tests that error is raised when labels is None.\"\"\"\n    head = head_lib.MultiClassHead(n_classes=3)\n\n    with self.assertRaisesRegexp(\n        ValueError, r'You must provide a labels Tensor\\. 
Given: None\\.'):\n      head.loss(\n          logits=np.array((\n              (10, 0, 0),\n              (0, 10, 0),\n          ), dtype=np.float32),\n          labels=None,\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.EVAL)\n\n  def test_eval(self):\n    n_classes = 3\n    head = head_lib.MultiClassHead(n_classes)\n    logits = np.array((\n        (10, 0, 0),\n        (0, 10, 0),\n    ), dtype=np.float32)\n    labels = np.array(((1,), (1,)), dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # loss = sum(cross_entropy(labels, logits)) / batch_size\n    #      = sum(10, 0) / 2 = 5.\n    expected_loss = 5.\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: expected_loss,\n        keys.ACCURACY: 0.5,  # 1 of 2 labels is correct.\n    }\n    tol = 1e-2\n\n    if tf.executing_eagerly():\n      eval_metrics = head.metrics()\n      updated_metrics = head.update_metrics(eval_metrics, features, logits,\n                                            labels)\n      self.assertItemsEqual(expected_metrics.keys(), updated_metrics.keys())\n      self.assertAllClose(\n          expected_metrics,\n          {k: updated_metrics[k].result() for k in updated_metrics},\n          rtol=tol,\n          atol=tol)\n      loss = head.loss(labels, logits, features=features, mode=ModeKeys.EVAL)\n      self.assertIsNotNone(loss)\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      return\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    # Assert spec contains expected tensors.\n    self.assertIsNotNone(spec.loss)\n    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())\n    self.assertIsNone(spec.train_op)\n    
self.assertIsNone(spec.export_outputs)
    test_lib._assert_no_hooks(self, spec)

    # Assert predictions, loss, and metrics.
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      self.assertIsNone(spec.scaffold.summary_op)
      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
      loss, _ = sess.run((spec.loss, update_ops))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      # Check results of value ops (in `metrics`).
      self.assertAllClose(
          expected_metrics, {k: value_ops[k].eval() for k in value_ops},
          rtol=tol,
          atol=tol)

  def test_eval_metric_ops_with_head_name(self):
    """Verifies eval metric keys are prefixed with the head name."""
    n_classes = 3
    head = head_lib.MultiClassHead(n_classes, name='some_multiclass_head')
    logits = np.array((
        (10, 0, 0),
        (0, 10, 0),
    ), dtype=np.float32)
    labels = np.array(((1,), (1,)), dtype=np.int64)
    features = {'x': np.array(((42,),), dtype=np.int32)}
    expected_metric_keys = [
        '{}/some_multiclass_head'.format(metric_keys.MetricKeys.LOSS_MEAN),
        '{}/some_multiclass_head'.format(metric_keys.MetricKeys.ACCURACY)
    ]

    eval_metrics = head.metrics()
    updated_metrics = head.update_metrics(eval_metrics, features, logits,
                                          labels)
    self.assertItemsEqual(expected_metric_keys, updated_metrics.keys())

  def test_eval_with_regularization_losses(self):
    """Tests eval loss and metrics when regularization losses are given."""
    n_classes = 3
    head = head_lib.MultiClassHead(n_classes)
    logits = np.array((
        (10, 0, 0),
        (0, 10, 0),
    ), dtype=np.float32)
    labels = np.array(((1,), (1,)), dtype=np.int64)
    features = {'x': np.array(((42,),), dtype=np.int32)}
    regularization_losses = [1.5, 0.5]
    expected_regularization_loss = 2.
    # unregularized_loss = sum(cross_entropy(labels, logits)) / batch_size
    #                    = sum(10, 0) / 2 = 5.
    expected_unregularized_loss = 5.
    expected_regularized_loss = (
        expected_unregularized_loss + expected_regularization_loss)

    keys = metric_keys.MetricKeys
    expected_metrics = {
        keys.LOSS_MEAN: expected_unregularized_loss,
        keys.LOSS_REGULARIZATION: expected_regularization_loss,
        keys.ACCURACY: 0.5,  # 1 of 2 labels is correct.
    }
    tol = 1e-2
    if tf.executing_eagerly():
      eval_metrics = head.metrics(regularization_losses=regularization_losses)
      updated_metrics = head.update_metrics(
          eval_metrics,
          features,
          logits,
          labels,
          regularization_losses=regularization_losses)
      # Assert metrics.
      self.assertAllClose(
          expected_metrics,
          {k: updated_metrics[k].result() for k in updated_metrics},
          rtol=tol,
          atol=tol)
      return

    # Create estimator spec.
    spec = head.create_estimator_spec(
        features=features,
        mode=ModeKeys.EVAL,
        logits=logits,
        labels=labels,
        regularization_losses=regularization_losses,
        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])

    # Assert predictions, loss, and metrics.
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      self.assertIsNone(spec.scaffold.summary_op)
      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
      loss, _ = sess.run((spec.loss, update_ops))
      self.assertAllClose(expected_regularized_loss, loss, rtol=tol, atol=tol)
      # Check results of value ops (in `metrics`).
      self.assertAllClose(
          expected_metrics, {k: value_ops[k].eval() for k in value_ops},
          rtol=tol,
          atol=tol)

  def test_eval_with_label_vocabulary_create_loss(self):
    """Tests EVAL-mode loss when labels are vocabulary strings."""
    n_classes = 3
    head = head_lib.MultiClassHead(
        n_classes, label_vocabulary=['aang', 'iroh', 'zuko'])
    logits = [[10., 0, 0], [0, 10, 0]]
    labels = [[b'iroh'], [b'iroh']]
    features = {'x': np.array(((42,),), dtype=np.int32)}
    # loss = sum(cross_entropy(labels, logits)) / batch_size = [5.0, 0].
    expected_training_loss = 5.
    if tf.executing_eagerly():
      training_loss = head.loss(
          logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)
      self.assertAllClose(
          expected_training_loss, training_loss, rtol=1e-2, atol=1e-2)
    else:
      training_loss = head.loss(
          logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)
      with self.cached_session():
        test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())
        self.assertAllClose(
            expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)

  def test_eval_with_label_vocabulary(self):
    """Tests eval loss and metrics with a string label vocabulary."""
    n_classes = 3
    head = head_lib.MultiClassHead(
        n_classes, label_vocabulary=['aang', 'iroh', 'zuko'])

    logits = [[10., 0, 0], [0, 10, 0]]
    labels = [[b'iroh'], [b'iroh']]
    features = {'x': np.array(((42,),), dtype=np.int32)}
    # loss = sum(cross_entropy(labels, logits))  / batch_size
    #      = sum(10, 0) / 2 = 5.
    expected_loss = 5.
    keys = metric_keys.MetricKeys
    expected_metrics = {
        keys.LOSS_MEAN: expected_loss,
        keys.ACCURACY: 0.5,  # 1 of 2 labels is correct.
    }
    tol = 1e-2
    if tf.executing_eagerly():
      loss = head.loss(
          logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)
      self.assertAllClose(
          expected_loss, self.evaluate(loss), rtol=tol, atol=tol)
      eval_metrics = head.metrics()
      updated_metrics = head.update_metrics(eval_metrics, features, logits,
                                            labels)
      self.assertAllClose(
          expected_metrics,
          {k: updated_metrics[k].result() for k in updated_metrics},
          rtol=tol,
          atol=tol)
      return

    spec = head.create_estimator_spec(
        features=features,
        mode=ModeKeys.EVAL,
        logits=logits,
        labels=labels,
        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])

    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
      loss, _ = sess.run((spec.loss, update_ops))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      # Check results of value ops (in `metrics`).
      self.assertAllClose(
          expected_metrics, {k: value_ops[k].eval() for k in value_ops},
          rtol=tol,
          atol=tol)

  def test_weighted_multi_example_eval(self):
    """Tests eval with three examples and per-example weights."""
    n_classes = 3
    head = head_lib.MultiClassHead(n_classes, weight_column='label_weights')

    # Create estimator spec.
    logits = np.array((
        (10, 0, 0),
        (0, 10, 0),
        (0, 0, 10),
    ), dtype=np.float32)
    labels = np.array(((1,), (2,), (2,)), dtype=np.int64)
    weights_3x1 = np.array(((1.,), (2.,), (3.,)), dtype=np.float64)
    # weighted_loss = sum(cross_entropy(labels, logits) *  weights)
    #      = sum([10, 10, 0] * [1, 2, 3])
    #      = sum([10, 20, 0]) = 30.
    # loss = weighted_loss  / batch_size = 30 / 3 = 10
    # loss_mean = weighted_loss / sum(weights) = 30 / 6 = 5
    expected_loss = 10.
    features = {
        'x': np.array(((42,),), dtype=np.int32),
        'label_weights': weights_3x1
    }
    keys = metric_keys.MetricKeys
    expected_metrics = {
        keys.LOSS_MEAN: 30. / np.sum(weights_3x1),
        # Weighted accuracy is 1 * 3.0 / sum weights = 0.5
        keys.ACCURACY: 0.5,
    }

    tol = 1e-2
    if tf.executing_eagerly():
      loss = head.loss(
          logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)
      self.assertIsNotNone(loss)
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      eval_metrics = head.metrics()
      updated_metrics = head.update_metrics(eval_metrics, features, logits,
                                            labels)
      self.assertItemsEqual(expected_metrics.keys(), updated_metrics.keys())
      self.assertAllClose(
          expected_metrics,
          {k: updated_metrics[k].result() for k in updated_metrics},
          rtol=tol,
          atol=tol)
      return

    spec = head.create_estimator_spec(
        features=features,
        mode=ModeKeys.EVAL,
        logits=logits,
        labels=labels,
        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])
    # Assert spec contains expected tensors.
    self.assertIsNotNone(spec.loss)
    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
    self.assertIsNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    test_lib._assert_no_hooks(self, spec)

    # Assert loss, and metrics.
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      self.assertIsNone(spec.scaffold.summary_op)
      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
      loss, _ = sess.run((spec.loss, update_ops))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      # Check results of value ops (in `metrics`).
      self.assertAllClose(
          expected_metrics, {k: value_ops[k].eval() for k in value_ops},
          rtol=tol,
          atol=tol)

  def test_train_create_loss(self):
    """Tests training loss and unweighted loss/weights with default weights."""
    head = head_lib.MultiClassHead(n_classes=3)

    logits = np.array((
        (10, 0, 0),
        (0, 10, 0),
    ), dtype=np.float32)
    labels = np.array(((1,), (1,)), dtype=np.int64)
    features = {'x': np.array(((42,),), dtype=np.int32)}
    # unreduced_loss = cross_entropy(labels, logits) = [10, 0].
    expected_unreduced_loss = [[10.], [0.]]
    # Weights default to 1.
    expected_weights = 1.
    # training_loss = (1 * 10 + 1 * 0) / 2 = 5.
    expected_training_loss = 5.
    tol = 1e-2
    if tf.executing_eagerly():
      training_loss = head.loss(labels, logits, features)
      self.assertAllClose(
          expected_training_loss, training_loss, rtol=tol, atol=tol)
      unreduced_loss, actual_weights = head._unweighted_loss_and_weights(
          logits, labels, features)
      self.assertAllClose(
          expected_unreduced_loss, unreduced_loss, rtol=tol, atol=tol)
      self.assertAllClose(expected_weights, actual_weights)
      return

    training_loss = head.loss(labels, logits, features)
    unreduced_loss, actual_weights = head._unweighted_loss_and_weights(
        logits, labels, features)
    with self.cached_session():
      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())
      self.assertAllClose(
          expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
      self.assertAllClose(
          expected_unreduced_loss, unreduced_loss.eval(), rtol=tol, atol=tol)
      self.assertAllClose(expected_weights, actual_weights)

  def test_train_create_loss_loss_reduction(self):
    """Tests create_loss with loss_reduction."""
    head = head_lib.MultiClassHead(
        n_classes=3, loss_reduction=tf.losses.Reduction.SUM)

    logits = np.array((
        (10, 0, 0),
        (0, 10, 0),
    ), dtype=np.float32)
    labels = np.array(((1,), (1,)), dtype=np.int64)
    features = {'x': np.array(((42,),), dtype=np.int32)}

    # unreduced_loss = cross_entropy(labels, logits) = [10, 0].
    expected_unreduced_loss = [[10.], [0.]]
    # Weights default to 1.
    expected_weights = 1.
    # training_loss = 1 * 10 + 1 * 0
    expected_training_loss = 10.
    tol = 1e-2
    if tf.executing_eagerly():
      training_loss = head.loss(labels, logits, features)
      self.assertAllClose(
          expected_training_loss, training_loss, rtol=tol, atol=tol)
      unreduced_loss, actual_weights = head._unweighted_loss_and_weights(
          logits, labels, features)
      self.assertAllClose(
          expected_unreduced_loss, unreduced_loss, rtol=tol, atol=tol)
      self.assertAllClose(expected_weights, actual_weights)
      return

    training_loss = head.loss(labels, logits, features)
    unreduced_loss, actual_weights = head._unweighted_loss_and_weights(
        logits, labels, features)
    with self.cached_session():
      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())
      self.assertAllClose(
          expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
      self.assertAllClose(
          expected_unreduced_loss, unreduced_loss.eval(), rtol=tol, atol=tol)
      self.assertAllClose(expected_weights, actual_weights)

  def test_train_labels_none(self):
    """Tests that error is raised when labels is None."""
    head = head_lib.MultiClassHead(n_classes=3)

    with self.assertRaisesRegexp(
        ValueError, r'You must provide a labels Tensor\. Given: None\.'):
      head.loss(
          logits=np.array((
              (10, 0, 0),
              (0, 10, 0),
          ), dtype=np.float32),
          labels=None,
          features={'x': np.array(((42,),), dtype=np.int32)},
          mode=ModeKeys.TRAIN)

  def test_train(self):
    """Tests TRAIN-mode loss, train_op, and summaries."""
    n_classes = 3
    head = head_lib.MultiClassHead(n_classes)

    logits = np.array((
        (10, 0, 0),
        (0, 10, 0),
    ), dtype=np.float32)
    labels = np.array(((1,), (1,)), dtype=np.int64)
    features = {'x': np.array(((42,),), dtype=np.int32)}

    # loss = sum(cross_entropy(labels, logits)) / batch_size
    #      = sum(10, 0) / 2 = 5.
    expected_loss = 5.
    tol = 1e-2
    if tf.executing_eagerly():
      loss = head.loss(
          logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)
      self.assertIsNotNone(loss)
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      return

    expected_train_result = 'my_train_op'

    def _train_op_fn(loss):
      return tf.strings.join([
          tf.constant(expected_train_result),
          tf.strings.as_string(loss, precision=2)
      ])

    spec = head.create_estimator_spec(
        features=features,
        mode=ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn,
        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])

    self.assertIsNotNone(spec.loss)
    self.assertEqual({}, spec.eval_metric_ops)
    self.assertIsNotNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    test_lib._assert_no_hooks(self, spec)

    # Assert predictions, loss, train_op, and summaries.
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      self.assertIsNotNone(spec.scaffold.summary_op)
      loss, train_result, summary_str = sess.run(
          (spec.loss, spec.train_op, spec.scaffold.summary_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),
          train_result)
      test_lib._assert_simple_summaries(self, {
          metric_keys.MetricKeys.LOSS: expected_loss,
      }, summary_str, tol)

  def test_train_with_regularization_losses(self):
    """Tests train loss and summaries when regularization losses are given."""
    n_classes = 3
    head = head_lib.MultiClassHead(n_classes)

    logits = np.array((
        (10, 0, 0),
        (0, 10, 0),
    ), dtype=np.float32)
    labels = np.array(((1,), (1,)), dtype=np.int64)
    features = {'x': np.array(((42,),), dtype=np.int32)}

    regularization_losses = [1.5, 0.5]
    expected_regularization_loss = 2.
    # unregularized_loss = sum(cross_entropy(labels, logits)) / batch_size
    #                    = sum(10, 0) / 2 = 5.
    # loss = unregularized_loss + regularization_loss = 7.
    expected_loss = 7.
    tol = 1e-2
    if tf.executing_eagerly():
      loss = head.loss(
          logits=logits,
          labels=labels,
          features=features,
          mode=ModeKeys.TRAIN,
          regularization_losses=regularization_losses)
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      return

    expected_train_result = 'my_train_op'

    def _train_op_fn(loss):
      return tf.strings.join([
          tf.constant(expected_train_result),
          tf.strings.as_string(loss, precision=2)
      ])

    spec = head.create_estimator_spec(
        features=features,
        mode=ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn,
        regularization_losses=regularization_losses,
        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])

    # Assert predictions, loss, train_op, and summaries.
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      self.assertIsNotNone(spec.scaffold.summary_op)
      loss, train_result, summary_str = sess.run(
          (spec.loss, spec.train_op, spec.scaffold.summary_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),
          train_result)
      test_lib._assert_simple_summaries(
          self, {
              metric_keys.MetricKeys.LOSS:
                  expected_loss,
              metric_keys.MetricKeys.LOSS_REGULARIZATION:
                  (expected_regularization_loss),
          }, summary_str, tol)

  def test_train_one_dim_create_loss(self):
    """Tests create_loss with 1D labels and weights (shape [batch_size])."""
    head = head_lib.MultiClassHead(n_classes=3, weight_column='label_weights')

    logits = np.array((
        (10, 0, 0),
        (0, 10, 0),
        (0, 0, 10),
    ), dtype=np.float32)
    labels_rank_1 = np.array((
        1,
        2,
        2,
    ), dtype=np.int64)
    weights_rank_1 = np.array((
        1.,
        2.,
        3.,
    ), dtype=np.float64)
    features = {
        'x': np.array(((42,),), dtype=np.float32),
        'label_weights': weights_rank_1
    }

    # unreduced_loss = cross_entropy(labels, logits) = [10, 10, 0].
    # weights are reshaped to [3, 1] to match logits.
    # training_loss = sum(1 * 10 + 2 * 10 + 3 * 0) / batch_size = 30. / 3 = 10.
    expected_training_loss = 10.
    tol = 1e-2

    if tf.executing_eagerly():
      training_loss = head.loss(labels_rank_1, logits, features)
      self.assertAllClose(
          expected_training_loss, training_loss, rtol=tol, atol=tol)
      return

    training_loss = head.loss(labels_rank_1, logits, features)
    with self.cached_session():
      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())
      self.assertAllClose(
          expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)

  def test_train_one_dim(self):
    """Tests train with 1D labels and weights (shape [batch_size])."""
    head = head_lib.MultiClassHead(n_classes=3, weight_column='label_weights')

    logits = np.array((
        (10, 0, 0),
        (0, 10, 0),
        (0, 0, 10),
    ), dtype=np.float32)
    labels_rank_1 = np.array((
        1,
        2,
        2,
    ), dtype=np.int64)
    weights_rank_1 = np.array((
        1.,
        2.,
        3.,
    ), dtype=np.float64)

    self.assertEqual((3,), labels_rank_1.shape)
    self.assertEqual((3,), weights_rank_1.shape)
    # loss = sum(cross_entropy(labels, logits) * [1, 2, 3]) / batch_size
    #      = sum([10, 10, 0] * [1, 2, 3]) / 3 = 30 / 3 = 10.
    expected_loss = 10.
    features = {
        'x': np.array(((42,),), dtype=np.float32),
        'label_weights': weights_rank_1
    }
    tol = 1e-2
    if tf.executing_eagerly():
      loss = head.loss(
          logits=logits,
          labels=labels_rank_1,
          features=features,
          mode=ModeKeys.TRAIN)
      self.assertIsNotNone(loss)
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      return

    expected_train_result = 'my_train_op'

    def _train_op_fn(loss):
      return tf.strings.join([
          tf.constant(expected_train_result),
          tf.strings.as_string(loss, precision=2)
      ])

    spec = head.create_estimator_spec(
        features=features,
        mode=ModeKeys.TRAIN,
        logits=logits,
        labels=labels_rank_1,
        train_op_fn=_train_op_fn,
        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])

    self.assertIsNotNone(spec.loss)
    self.assertEqual({}, spec.eval_metric_ops)
    self.assertIsNotNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    test_lib._assert_no_hooks(self, spec)

    # Assert predictions, loss, train_op, and summaries.
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      self.assertIsNotNone(spec.scaffold.summary_op)
      loss, train_result, summary_str = sess.run(
          (spec.loss, spec.train_op, spec.scaffold.summary_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),
          train_result)
      test_lib._assert_simple_summaries(self, {
          metric_keys.MetricKeys.LOSS: expected_loss,
      }, summary_str, tol)

  def test_train_with_vocabulary_create_loss(self):
    """Tests TRAIN-mode loss when labels are vocabulary strings."""
    n_classes = 3
    head = head_lib.MultiClassHead(
        n_classes, label_vocabulary=['aang', 'iroh', 'zuko'])

    logits = [[10., 0, 0], [0, 10, 0]]
    labels = [[b'iroh'], [b'iroh']]
    features = {'x': np.array(((42,),), dtype=np.int32)}
    # loss = sum(cross_entropy(labels, logits)) / batch_size = 10 / 2 = 5.
    expected_training_loss = 5.
    if tf.executing_eagerly():
      training_loss = head.loss(
          logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)
      self.assertAllClose(
          expected_training_loss, training_loss, rtol=1e-2, atol=1e-2)
      return

    training_loss = head.loss(
        logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)
    with self.cached_session():
      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())
      self.assertAllClose(
          expected_training_loss, training_loss.eval(), rtol=1e-2, atol=1e-2)

  def test_train_with_vocabulary(self):
    """Tests training loss with a string label vocabulary."""
    n_classes = 3
    head = head_lib.MultiClassHead(
        n_classes, label_vocabulary=['aang', 'iroh', 'zuko'])

    logits = [[10., 0, 0], [0, 10, 0]]
    labels = [[b'iroh'], [b'iroh']]
    features = {'x': np.array(((42,),), dtype=np.int32)}
    # loss = sum(cross_entropy(labels, logits)) / batch_size
    #      = sum(10, 0) / 2 = 5.
    expected_loss = 5.
    tol = 1e-2
    if tf.executing_eagerly():
      loss = head.loss(
          logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      return

    def _train_op_fn(loss):
      del loss
      return tf.no_op()

    spec = head.create_estimator_spec(
        features=features,
        mode=ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn,
        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      loss = sess.run(spec.loss)
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)

  def test_weighted_multi_example_train(self):
    """Tests train with three examples and per-example weights."""
    n_classes = 3
    head = head_lib.MultiClassHead(n_classes, weight_column='label_weights')

    logits = np.array((
        (10, 0, 0),
        (0, 10, 0),
        (0, 0, 10),
    ), dtype=np.float32)
    labels = np.array(((1,), (2,), (2,)), dtype=np.int64)
    weights_3x1 = np.array(((1.,), (2.,), (3.,)), dtype=np.float64)
    expected_train_result = 'my_train_op'
    # loss = sum(cross_entropy(labels, logits) * [1, 2, 3]) / batch_size
    #      = sum([10, 10, 0] * [1, 2, 3]) / 3 = 30 / 3 = 10
    expected_loss = 10.
    tol = 1e-2
    features = {
        'x': np.array(((42,),), dtype=np.float32),
        'label_weights': weights_3x1
    }
    if tf.executing_eagerly():
      loss = head.loss(
          logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)
      self.assertIsNotNone(loss)
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      return

    def _train_op_fn(loss):
      return tf.strings.join([
          tf.constant(expected_train_result),
          tf.strings.as_string(loss, precision=2)
      ])

    # Create estimator spec.
    spec = head.create_estimator_spec(
        features=features,
        mode=ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn,
        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])

    self.assertIsNotNone(spec.loss)
    self.assertEqual({}, spec.eval_metric_ops)
    self.assertIsNotNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    test_lib._assert_no_hooks(self, spec)

    # Assert predictions, loss, train_op, and summaries.
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      self.assertIsNotNone(spec.scaffold.summary_op)
      loss, train_result, summary_str = sess.run(
          (spec.loss, spec.train_op, spec.scaffold.summary_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),
          train_result)
      test_lib._assert_simple_summaries(self, {
          metric_keys.MetricKeys.LOSS: expected_loss,
      }, summary_str, tol)

  def test_multi_dim_weighted_train_create_loss(self):
    """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2]."""
    head = head_lib.MultiClassHead(n_classes=3, weight_column='weights')

    logits = np.array([[[10, 0, 0], [12, 0, 0]], [[0, 10, 0], [0, 15, 0]]],
                      dtype=np.float32)
    labels = np.array([[[0], [1]], [[1], [2]]], dtype=np.int64)
    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)

    # unreduced_loss = cross_entropy(labels, logits) = [[0, 12], [0, 15]].
    # weights are reshaped to [2, 2, 1] to match logits.
    # training_loss = sum(1*0 + 1.5*12 + 2*0 + 2.5*15) / batch_size
    #               = 55.5 / (2*2) = 13.875
    expected_training_loss = 13.875
    tol = 1e-2
    if tf.executing_eagerly():
      training_loss = head.loss(labels, logits, features={'weights': weights})
      self.assertAllClose(
          expected_training_loss, training_loss, rtol=tol, atol=tol)
      return

    training_loss = head.loss(labels, logits, features={'weights': weights})
    with self.cached_session():
      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())
      self.assertAllClose(
          expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)

  def test_multi_dim_weighted_train(self):
    """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2]."""
    head = head_lib.MultiClassHead(n_classes=3, weight_column='weights')

    logits = np.array([[[10, 0, 0], [12, 0, 0]], [[0, 10, 0], [0, 15, 0]]],
                      dtype=np.float32)
    labels = np.array([[[0], [1]], [[1], [2]]], dtype=np.int64)
    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
    tol = 1e-2
    # loss = cross_entropy(labels, logits) = [[0, 12], [0, 15]].
    # weighted_sum_loss = (1*0 + 1.5*12 + 2*0 + 2.5*15) = 55.5
    # training_loss = weighted_sum_loss / batch_size  = 55.5 / (2*2) = 13.875
    expected_loss = 13.875
    if tf.executing_eagerly():
      loss = head.loss(
          logits=logits,
          labels=labels,
          features={'weights': weights},
          mode=ModeKeys.TRAIN)
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      return

    expected_train_result = 'my_train_op'

    def _train_op_fn(loss):
      return tf.strings.join([
          tf.constant(expected_train_result),
          tf.strings.as_string(loss, precision=2)
      ])

    spec = head.create_estimator_spec(
        features={'weights': weights},
        mode=ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn,
        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])

    # Assert predictions, loss, train_op, and summaries.
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      loss, train_result = sess.run((spec.loss, spec.train_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),
          train_result)

  def test_multi_dim_train_weights_wrong_inner_dim(self):
    """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 1]."""
    head = head_lib.MultiClassHead(n_classes=3, weight_column='weights')
    logits = np.array([[[10, 0, 0], [12, 0, 0]], [[0, 10, 0], [0, 15, 0]]],
                      dtype=np.float32)
    labels = np.array([[[0], [1]], [[1], [2]]], dtype=np.int64)
    weights = np.array([[1.], [2.]], dtype=np.float32)

    if tf.executing_eagerly():
      with self.assertRaisesRegexp(ValueError, 'weights shape'):
        head.loss(
            logits=logits,
            labels=labels,
            features={'weights': weights},
            mode=ModeKeys.TRAIN)
      return

    def _no_op_train_fn(loss):
      del loss
      return tf.no_op()

    spec = head.create_estimator_spec(
        features={'weights': weights},
        mode=ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_no_op_train_fn,
        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])
    with self.cached_session():
      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())
      with self.assertRaisesRegexp(
          tf.errors.InvalidArgumentError,
          r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 1\]'):
        spec.loss.eval()

  def test_multi_dim_train_weights_wrong_outer_dim(self):
    """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2, 3]."""
    head = head_lib.MultiClassHead(n_classes=3, weight_column='weights')
    logits = np.array([[[10, 0, 0], [12, 0, 0]], [[0, 10, 0], [0, 15, 0]]],
                      dtype=np.float32)
    labels = np.array([[[0], [1]], [[1], [2]]], dtype=np.int64)
    weights = np.array([[[1., 1.1, 1.2], [1.5, 1.6, 1.7]],
                        [[2., 2.1, 2.2], [2.5, 2.6, 2.7]]])

    if tf.executing_eagerly():
      with self.assertRaisesRegexp(ValueError, 'weights shape'):
        head.loss(
            logits=logits,
            labels=labels,
            features={'weights': weights},
            mode=ModeKeys.TRAIN)
      return

    weights_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)

    def _no_op_train_fn(loss):
      del loss
      return tf.no_op()

    spec = head.create_estimator_spec(
        features={'weights': weights_placeholder},
        mode=ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_no_op_train_fn,
        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])
    with self.cached_session():
      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())
      with self.assertRaisesRegexp(
          tf.errors.InvalidArgumentError,
          r'\[logits_shape: \]\s\[2 2 3\]\s\[weights_shape: \]\s\[2 2 3\]'):
        spec.loss.eval({weights_placeholder: weights})

  def test_multi_dim_weighted_eval(self):
    """Logits of shape [2, 2, 2], labels [2, 2, 1], weights [2, 2]."""
    head = head_lib.MultiClassHead(n_classes=3, weight_column='weights')
    logits = np.array([[[10, 0, 0], [12, 0, 0]], [[0, 10, 0], [0, 15, 0]]],
                      dtype=np.float32)
    labels = np.array([[[0], [1]], [[1], [2]]], dtype=np.int64)
    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
    # loss = cross_entropy(labels, logits) = [[0, 12], [0, 15]].
    # weighted_sum_loss = 1*0 + 1.5*12 + 2*0 + 2.5*15 = 55.5
    # training_loss = weighted_sum_loss / batch_size = 55.5 / (2*2) = 13.875
    expected_loss = 13.875
    keys = metric_keys.MetricKeys
    expected_metrics = {
        keys.LOSS_MEAN:
            55.5 / np.sum(weights),
        keys.ACCURACY:
            (1. * 1. + 1.5 * 0. + 2. * 1. + 2.5 * 0.) / np.sum(weights),
    }
    tol = 1e-2
    if tf.executing_eagerly():
      loss = head.loss(
          logits=logits,
          labels=labels,
          features={'weights': weights},
          mode=ModeKeys.TRAIN)
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)

      eval_metrics = head.metrics()
      updated_metrics = head.update_metrics(
          eval_metrics,
          features={'weights': weights},
          logits=logits,
          labels=labels)
      # Assert metrics.
      self.assertAllClose(
          expected_metrics,
          {k: updated_metrics[k].result() for k in updated_metrics},
          rtol=tol,
          atol=tol)
      return

    # Create estimator spec.
    spec = head.create_estimator_spec(
        features={'weights': weights},
        mode=ModeKeys.EVAL,
        logits=logits,
        labels=labels,
        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])
    # Assert predictions, loss, and metrics.
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
      loss, _ = sess.run((spec.loss, update_ops))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      # Check results of value ops (in `metrics`).
      self.assertAllClose(
          expected_metrics, {k: value_ops[k].eval() for k in value_ops},
          rtol=tol,
          atol=tol)


@test_util.deprecated_graph_mode_only
class MultiClassHeadForEstimator(tf.test.TestCase):
  """Tests for create_estimator_spec running in Graph mode only."""

  def test_invalid_trainable_variables(self):
    """Tests errors for None and non-list/tuple trainable_variables."""
    n_classes = 3
    head = head_lib.MultiClassHead(n_classes)

    class _Optimizer(tf_keras.optimizers.Optimizer):

      def get_updates(self, loss, params):
        del params
        return [
            tf.strings.join([
                tf.constant('my_train_op'),
                tf.strings.as_string(loss, precision=2)
            ])
        ]

      def get_config(self):
        config = super(_Optimizer, self).get_config()
        return config

    with self.assertRaisesRegexp(ValueError,
                                 r'trainable_variables cannot be None'):
      head.create_estimator_spec(
          features={'x': np.array(((42,),), dtype=np.int32)},
          mode=ModeKeys.TRAIN,
          logits=np.array((
              (10, 0, 0),
              (0, 10, 0),
          ), dtype=np.float32),
          labels=np.array(((1,), (1,)), dtype=np.int64),
          optimizer=_Optimizer('my_optimizer'),
          trainable_variables=None)
    with self.assertRaisesRegexp(
        ValueError, r'trainable_variables should be a list or a tuple'):
      head.create_estimator_spec(
          features={'x': np.array(((42,),), dtype=np.int32)},
          mode=ModeKeys.TRAIN,
          logits=np.array((
              (10, 0, 0),
              (0, 10, 0),
          ), dtype=np.float32),
          labels=np.array(((1,), (1,)), dtype=np.int64),
          optimizer=_Optimizer('my_optimizer'),
          trainable_variables={
              'var_list': [tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)]
          })

  def test_train_with_optimizer(self):
    """Tests that the optimizer's update ops become the train_op."""
    n_classes = 3
    head = head_lib.MultiClassHead(n_classes)

    logits = np.array((
        (10, 0, 0),
        (0, 10, 0),
    ), dtype=np.float32)
    labels = np.array(((1,), (1,)), dtype=np.int64)
    features = {'x': np.array(((42,),), dtype=np.int32)}
    expected_train_result = 'my_train_op'

    class _Optimizer(tf_keras.optimizers.Optimizer):

      def get_updates(self, loss, params):
        del params
        return [
            tf.strings.join([
                tf.constant(expected_train_result),
                tf.strings.as_string(loss, precision=2)
            ])
        ]

      def get_config(self):
        config = super(_Optimizer, self).get_config()
        return config

    # loss = sum(cross_entropy(labels, logits)) / batch_size
    #      = sum(10, 0) / 2 = 5.
    expected_loss = 5.
    spec = head.create_estimator_spec(
        features=features,
        mode=ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        optimizer=_Optimizer('my_optimizer'),
        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])

    tol = 1e-2
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      loss, train_result = sess.run((spec.loss, spec.train_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.2f}'.format(expected_train_result, expected_loss)),
          train_result)

  def test_train_with_update_ops(self):
    """Tests that given update_ops run as part of the train_op."""
    n_classes = 3
    with tf.Graph().as_default():
      w = tf.Variable(1)
      update_op = w.assign_add(1)

      t = tf.Variable('')
      expected_train_result = b'my_train_op'

      def _train_op_fn(loss):
        del loss
        return t.assign(expected_train_result)

      head = head_lib.MultiClassHead(n_classes)

      spec = head.create_estimator_spec(
          features={'x': np.array(((42,),), dtype=np.int32)},
          mode=ModeKeys.TRAIN,
          logits=np.array((
              (10, 0, 0),
              (0, 10, 0),
          ), dtype=np.float32),
          labels=np.array(((1,), (1,)), dtype=np.int64),
          train_op_fn=_train_op_fn,
          update_ops=[update_op],
          trainable_variables=[
              tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)
          ])

      with self.cached_session() as sess:
        test_lib._initialize_variables(self, spec.scaffold)
        sess.run(spec.train_op)
        w_value, t_value = sess.run([w, t])
        self.assertEqual(2, w_value)
        self.assertEqual(expected_train_result, t_value)

  def test_train_summaries_with_head_name(self):
    """Verifies summary keys are prefixed with the head name."""
    n_classes = 3
    head = head_lib.MultiClassHead(n_classes, name='some_multiclass_head')

    logits = np.array((
        (10, 0, 0),
        (0, 10, 0),
    ), dtype=np.float32)
    labels = np.array(((1,), (1,)), dtype=np.int64)
    # loss = sum(cross_entropy(labels, logits)) / batch_size= sum(10, 0) / 2 = 5
    expected_loss = 5.
    features = {'x': np.array(((42,),), dtype=np.int32)}

    def _train_op_fn(loss):
      del loss
      return tf.no_op()

    spec = head.create_estimator_spec(
        features=features,
        mode=ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn,
        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])

    # Assert summaries.
    tol = 1e-2
    with self.cached_session() as sess:
      test_lib._initialize_variables(self, spec.scaffold)
      self.assertIsNotNone(spec.scaffold.summary_op)
      summary_str = sess.run(spec.scaffold.summary_op)
      test_lib._assert_simple_summaries(
          self, {
              '{}/some_multiclass_head'.format(metric_keys.MetricKeys.LOSS):
 
                 expected_loss,\n          }, summary_str, tol)\n\n  def test_lookup_tables_in_graph(self):\n    n_classes = 3\n    head = head_lib.MultiClassHead(\n        n_classes, label_vocabulary=['aang', 'iroh', 'zuko'])\n\n    feature_columns = [tf.feature_column.numeric_column('x')]\n    # Create dnn estimator.\n    est = dnn.DNNEstimatorV2(\n        head=head, hidden_units=(2, 2), feature_columns=feature_columns)\n\n    def input_fn():\n      return ({\n          'x': np.array((\n              (42,),\n              (43,),\n          ), dtype=np.int32)\n      }, [[b'iroh'], [b'iroh']])\n\n    # Train.\n    num_steps = 1\n    est.train(input_fn, steps=num_steps)\n    # Eval.\n    eval_results = est.evaluate(input_fn, steps=num_steps)\n    self.assertEqual(num_steps,\n                     eval_results[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(eval_results))\n    # Predict.\n    est.predict(input_fn)\n\n  def test_missmatch_n_classes_label_vocabulary(self):\n    with self.assertRaises(ValueError):\n      head_lib.MultiClassHead(\n          n_classes=3, label_vocabulary=['a', 'b', 'c', 'd'])\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/multi_head.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Multi head class.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nfrom tensorflow_estimator.python.estimator import model_fn\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.export import export_output\nfrom tensorflow_estimator.python.estimator.head import base_head\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n\ndef _no_op_train_fn(loss):\n  del loss\n  return tf.no_op()\n\n\ndef _default_export_output(export_outputs, head_name):\n  \"\"\"Extracts the default export output from the given export_outputs dict.\"\"\"\n  if len(export_outputs) == 1:\n    return next(six.itervalues(export_outputs))\n  try:\n    return export_outputs[base_head.DEFAULT_SERVING_KEY]\n  except KeyError:\n    raise ValueError(\n        '{} did not specify default export_outputs. 
'\n        'Given: {} '\n        'Suggested fix: Use one of the heads in tf.estimator, or include '\n        'key {} in export_outputs.'.format(head_name, export_outputs,\n                                           base_head.DEFAULT_SERVING_KEY))\n\n\n@estimator_export('estimator.MultiHead')\nclass MultiHead(base_head.Head):\n  \"\"\"Creates a `Head` for multi-objective learning.\n\n  This class merges the output of multiple `Head` objects. Specifically:\n\n  * For training, sums losses of each head, calls `train_op_fn` with this\n    final loss.\n  * For eval, merges metrics by adding `head.name` suffix to the keys in eval\n    metrics, such as `precision/head1.name`, `precision/head2.name`.\n  * For prediction, merges predictions and updates keys in prediction dict to a\n    2-tuple, `(head.name, prediction_key)`. Merges `export_outputs` such that\n    by default the first head is served.\n\n  Usage:\n\n  >>> head1 = tf.estimator.MultiLabelHead(n_classes=2, name='head1')\n  >>> head2 = tf.estimator.MultiLabelHead(n_classes=3, name='head2')\n  >>> multi_head = tf.estimator.MultiHead([head1, head2])\n  >>> logits = {\n  ...    'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),\n  ...    'head2': np.array([[20., -20., 20.], [-30., 20., -20.]],\n  ...    dtype=np.float32),}\n  >>> labels = {\n  ...    'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),\n  ...    
'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),}\n  >>> features = {'x': np.array(((42,),), dtype=np.float32)}\n  >>> # For large logits, sigmoid cross entropy loss is approximated as:\n  >>> # loss = labels * (logits < 0) * (-logits) +\n  >>> #        (1 - labels) * (logits > 0) * logits =>\n  >>> # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]]\n  >>> # loss1 = ((10 + 10) / 2 + (15 + 0) / 2) / 2 = 8.75\n  >>> # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]]\n  >>> # loss2 = ((20 + 20 + 20) / 3 + (30 + 0 + 0) / 3) / 2 = 15.00\n  >>> # loss = loss1 + loss2 = 8.75 + 15.00 = 23.75\n  >>> loss = multi_head.loss(labels, logits, features=features)\n  >>> print('{:.2f}'.format(loss.numpy()))\n  23.75\n  >>> eval_metrics = multi_head.metrics()\n  >>> updated_metrics = multi_head.update_metrics(\n  ...   eval_metrics, features, logits, labels)\n  >>> for k in sorted(updated_metrics):\n  ...  print('{} : {:.2f}'.format(k, updated_metrics[k].result().numpy()))\n  auc/head1 : 0.17\n  auc/head2 : 0.33\n  auc_precision_recall/head1 : 0.60\n  auc_precision_recall/head2 : 0.40\n  average_loss/head1 : 8.75\n  average_loss/head2 : 15.00\n  loss/head1 : 8.75\n  loss/head2 : 15.00\n  >>> preds = multi_head.predictions(logits)\n  >>> print(preds[('head1', 'logits')])\n  tf.Tensor(\n    [[-10.  10.]\n     [-15.  
10.]], shape=(2, 2), dtype=float32)\n\n  Usage with a canned estimator:\n\n  ```python\n  # In `input_fn`, specify labels as a dict keyed by head name:\n  def input_fn():\n    features = ...\n    labels1 = ...\n    labels2 = ...\n    return features, {'head1.name': labels1, 'head2.name': labels2}\n\n  # In `model_fn`, specify logits as a dict keyed by head name:\n  def model_fn(features, labels, mode):\n    # Create simple heads and specify head name.\n    head1 = tf.estimator.MultiClassHead(n_classes=3, name='head1')\n    head2 = tf.estimator.BinaryClassHead(name='head2')\n    # Create MultiHead from two simple heads.\n    head = tf.estimator.MultiHead([head1, head2])\n    # Create logits for each head, and combine them into a dict.\n    logits1, logits2 = logit_fn()\n    logits = {'head1.name': logits1, 'head2.name': logits2}\n    # Return the merged EstimatorSpec\n    return head.create_estimator_spec(..., logits=logits, ...)\n\n  # Create an estimator with this model_fn.\n  estimator = tf.estimator.Estimator(model_fn=model_fn)\n  estimator.train(input_fn=input_fn)\n  ```\n\n  Also supports `logits` as a `Tensor` of shape\n  `[D0, D1, ... DN, logits_dimension]`. It will split the `Tensor` along the\n  last dimension and distribute it appropriately among the heads. 
E.g.:\n\n  ```python\n  # Input logits.\n  logits = np.array([[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]],\n                    dtype=np.float32)\n  # Suppose head1 and head2 have the following logits dimension.\n  head1.logits_dimension = 2\n  head2.logits_dimension = 3\n  # After splitting, the result will be:\n  logits_dict = {'head1_name': [[-1., 1.], [-1.5, 1.]],\n                 'head2_name':  [[2., -2., 2.], [-3., 2., -2.]]}\n  ```\n\n  Usage:\n\n  ```python\n  def model_fn(features, labels, mode):\n    # Create simple heads and specify head name.\n    head1 = tf.estimator.MultiClassHead(n_classes=3, name='head1')\n    head2 = tf.estimator.BinaryClassHead(name='head2')\n    # Create multi-head from two simple heads.\n    head = tf.estimator.MultiHead([head1, head2])\n    # Create logits for the multihead. The result of logits is a `Tensor`.\n    logits = logit_fn(logits_dimension=head.logits_dimension)\n    # Return the merged EstimatorSpec\n    return head.create_estimator_spec(..., logits=logits, ...)\n  ```\n\n  Args:\n    heads: List or tuple of `Head` instances. All heads must have `name`\n      specified. The first head in the list is the default used at serving time.\n    head_weights: Optional list of weights, same length as `heads`. Used when\n      merging losses to calculate the weighted sum of losses from each head. If\n      `None`, all losses are weighted equally.\n  \"\"\"\n\n  def __init__(self, heads, head_weights=None):\n    if not heads:\n      raise ValueError('Must specify heads. Given: {}'.format(heads))\n    if head_weights:\n      if len(head_weights) != len(heads):\n        raise ValueError(\n            'heads and head_weights must have the same size. '\n            'Given len(heads): {}. 
Given len(head_weights): {}.'.format(\n                len(heads), len(head_weights)))\n    self._logits_dimension = 0\n    for head in heads:\n      if head.name is None:\n        raise ValueError(\n            'All given heads must have name specified. Given: {}'.format(head))\n      self._logits_dimension += head.logits_dimension\n    self._heads = tuple(heads)\n    self._head_weights = tuple(head_weights) if head_weights else tuple()\n    # Metric keys.\n    keys = metric_keys.MetricKeys\n    self._loss_regularization_key = self._summary_key(keys.LOSS_REGULARIZATION)\n    loss_keys = []\n    for head in self._heads:\n      loss_keys.append('{}/{}'.format(keys.LOSS, head.name))\n    self._loss_keys = tuple(loss_keys)\n\n  @property\n  def name(self):\n    \"\"\"See `base_head.Head` for details.\"\"\"\n    return '_'.join([h.name for h in self._heads])\n\n  @property\n  def logits_dimension(self):\n    \"\"\"See `base_head.Head` for details.\"\"\"\n    return self._logits_dimension\n\n  @property\n  def loss_reduction(self):\n    \"\"\"See `base_head.Head` for details.\"\"\"\n    loss_reductions = [head.loss_reduction for head in self._heads]\n    if len(set(loss_reductions)) > 1:\n      raise ValueError(\n          'The loss_reduction must be the same for different heads. 
'\n          'Given: {}'.format(loss_reductions))\n    return loss_reductions[0]\n\n  def _split_logits(self, logits):\n    \"\"\"Splits logits along the last dimension and returns a dict.\n\n    If the input logits is not a dict, splitting is applied based on the logits\n    dimension of each head.\n    For example:\n\n    ```python\n    # head1.logits_dimension = 2\n    # head2.logits_dimension = 3\n    head1 = tf.estimator.MultiLabelHead(n_classes=2, name='head1_name')\n    head2 = tf.estimator.MultiClassHead(n_classes=3, name='head2_name')\n    multi_head = tf.estimator.MultiHead([head1, head2])\n    # Input logits\n    logits = np.array([[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]],\n                      dtype=np.float32)\n    # As logits is not a dict, _split_logits is applied and returns the\n    # logits_dict as\n    logits_dict = {'head1_name': [[-1., 1.], [-1.5, 1.]],\n                   'head2_name':  [[2., -2., 2.], [-3., 2., -2.]]}\n    ```\n    Args:\n      logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.\n        For many applications, the shape is `[batch_size, logits_dimension]`.\n\n    Returns:\n      logits_dict: A dict of logits for each head.\n    \"\"\"\n    logits_dict = {}\n    with ops.name_scope('split_logits', values=[logits]):\n      logits = ops.convert_to_tensor(logits)\n      logits_dimensions = [head.logits_dimension for head in self._heads]\n      total_logits_dimension = sum(logits_dimensions)\n      logits_tensor_shape = logits.shape.as_list()\n      last_dimension_size = logits_tensor_shape[-1]\n      if last_dimension_size is not None:\n        if last_dimension_size != total_logits_dimension:\n          raise ValueError(\n              'Could not split logits of shape %r among the heads with '\n              'individual logits dimensions: %r. The last dimension of the '\n              'logits tensor should equal %d but is %d.' 
%\n              ((logits_tensor_shape, logits_dimensions, last_dimension_size,\n                total_logits_dimension)))\n\n      # TODO(b/119617064): unify eager and graph implementations\n      if tf.executing_eagerly():\n        logits_shape = logits._shape_tuple()  # pylint: disable=protected-access\n        batch_shape = logits_shape[:-1]\n      else:\n        batch_shape = tf.compat.v1.shape(logits)[:-1]\n      zeros_like_batch_shape = tf.compat.v1.zeros_like(batch_shape)\n      minus_ones_like_batch_shape = -1 * tf.compat.v1.ones_like(batch_shape)\n      begin_idx = 0\n      for head in self._heads:\n        begin_tensor = tf.concat([zeros_like_batch_shape, [begin_idx]], axis=0)\n        size_tensor = tf.concat(\n            [minus_ones_like_batch_shape, [head.logits_dimension]], axis=0)\n        logits_dict[head.name] = tf.slice(\n            logits, begin=begin_tensor, size=size_tensor)\n        begin_idx += head.logits_dimension\n    return logits_dict\n\n  def _check_logits_and_labels(self, logits, labels=None):\n    \"\"\"Validates the keys of logits and labels.\"\"\"\n    head_names = []\n    for head in self._heads:\n      head_names.append(head.name)\n    # Checks logits keys and splits it if it's not a dict\n    if isinstance(logits, dict):\n      logits_missing_names = list(set(head_names) - set(list(logits)))\n      if logits_missing_names:\n        raise ValueError('logits has missing values for head(s): {}'.format(\n            logits_missing_names))\n      logits_dict = logits\n    else:\n      logits_dict = self._split_logits(logits)\n    # Checks labels type and its keys\n    if labels is not None:\n      if not isinstance(labels, dict):\n        raise ValueError('labels must be a dict. 
Given: {}'.format(labels))\n      labels_missing_names = list(set(head_names) - set(list(labels)))\n      if labels_missing_names:\n        raise ValueError('labels has missing values for head(s): {}'.format(\n            labels_missing_names))\n    return logits_dict\n\n  def loss(self,\n           labels,\n           logits,\n           features=None,\n           mode=None,\n           regularization_losses=None):\n    \"\"\"Returns regularized training loss. See `base_head.Head` for details.\"\"\"\n    logits_dict = self._check_logits_and_labels(logits, labels)\n    training_losses = []\n    for head in self._heads:\n      training_loss = head.loss(\n          logits=logits_dict[head.name],\n          labels=labels[head.name],\n          features=features,\n          mode=mode)\n      training_losses.append(training_loss)\n\n    training_losses = tuple(training_losses)\n    with ops.name_scope(\n        'merge_losses',\n        values=training_losses + (self._head_weights or tuple())):\n      if self._head_weights:\n        head_weighted_training_losses = []\n        for training_loss, head_weight in zip(training_losses,\n                                              self._head_weights):\n          head_weighted_training_losses.append(\n              tf.math.multiply(training_loss, head_weight))\n        training_losses = head_weighted_training_losses\n      merged_training_loss = tf.math.add_n(training_losses)\n      regularization_loss = tf.math.add_n(\n          regularization_losses) if regularization_losses is not None else None\n      regularized_training_loss = (\n          merged_training_loss + regularization_loss\n          if regularization_loss is not None else merged_training_loss)\n    return regularized_training_loss\n\n  def predictions(self, logits, keys=None):\n    \"\"\"Create predictions. 
See `base_head.Head` for details.\"\"\"\n    logits_dict = self._check_logits_and_labels(logits)\n    predictions = {}\n    with ops.name_scope('merge_pred'):\n      for head in self._heads:\n        head_preds = head.predictions(logits=logits_dict[head.name])\n        for k, v in six.iteritems(head_preds):\n          predictions[(head.name, k)] = v\n    return predictions\n\n  def metrics(self, regularization_losses=None):\n    \"\"\"Creates metrics. See `base_head.Head` for details.\"\"\"\n    eval_metrics = {}\n    keys = metric_keys.MetricKeys\n    # Add regularization loss metric for multi_head.\n    if regularization_losses is not None:\n      eval_metrics[self._loss_regularization_key] = tf_keras.metrics.Mean(\n          name=keys.LOSS_REGULARIZATION)\n    with ops.name_scope('merge_eval'):\n      # Loss metric is not added by default in each head.\n      for loss_key in self._loss_keys:\n        eval_metrics[loss_key] = tf_keras.metrics.Mean(name=loss_key)\n    return eval_metrics\n\n  def update_metrics(self,\n                     eval_metrics,\n                     features,\n                     logits,\n                     labels,\n                     regularization_losses=None):\n    \"\"\"Updates eval metrics. 
See `base_head.Head` for details.\"\"\"\n    logits_dict = self._check_logits_and_labels(logits, labels)\n    # Update regularization loss metric\n    if regularization_losses is not None:\n      regularization_loss = tf.math.add_n(regularization_losses)\n      eval_metrics[self._loss_regularization_key].update_state(\n          values=regularization_loss)\n    # Update metrics for each head\n    for i, head in enumerate(self._heads):\n      head_logits = logits_dict[head.name]\n      head_labels = labels[head.name]\n      # Update loss metrics\n      training_loss = head.loss(\n          logits=head_logits, labels=head_labels, features=features)\n      eval_metrics[self._loss_keys[i]].update_state(values=training_loss)\n      # Update existing metrics in each head\n      head_metrics = head.metrics()\n      updated_metrics = head.update_metrics(head_metrics, features, head_logits,\n                                            head_labels)\n      eval_metrics.update(updated_metrics or {})\n    return eval_metrics\n\n  def create_estimator_spec(self,\n                            features,\n                            mode,\n                            logits,\n                            labels=None,\n                            optimizer=None,\n                            trainable_variables=None,\n                            train_op_fn=None,\n                            update_ops=None,\n                            regularization_losses=None):\n    \"\"\"Returns a `model_fn.EstimatorSpec`.\n\n    Args:\n      features: Input `dict` of `Tensor` or `SparseTensor` objects.\n      mode: Estimator's `ModeKeys`.\n      logits: Input `dict` keyed by head name, or logits `Tensor` with shape\n        `[D0, D1, ... DN, logits_dimension]`. For many applications, the\n        `Tensor` shape is `[batch_size, logits_dimension]`. If logits is a\n        `Tensor`, it  will split the `Tensor` along the last dimension and\n        distribute it appropriately among the heads. 
Check `MultiHead` for\n        examples.\n      labels: Input `dict` keyed by head name. For each head, the label value\n        can be integer or string `Tensor` with shape matching its corresponding\n        `logits`.`labels` is a required argument when `mode` equals `TRAIN` or\n        `EVAL`.\n      optimizer: An `tf_keras.optimizers.Optimizer` instance to optimize the\n        loss in TRAIN mode. Namely, sets `train_op = optimizer.get_updates(loss,\n        trainable_variables)`, which updates variables to minimize `loss`.\n      trainable_variables: A list or tuple of `Variable` objects to update to\n        minimize `loss`. In Tensorflow 1.x, by default these are the list of\n        variables collected in the graph under the key\n        `GraphKeys.TRAINABLE_VARIABLES`. As Tensorflow 2.x doesn't have\n        collections and GraphKeys, trainable_variables need to be passed\n        explicitly here.\n      train_op_fn: Function that takes a scalar loss `Tensor` and returns\n        `train_op`. Used if `optimizer` is `None`.\n      update_ops: A list or tuple of update ops to be run at training time. For\n        example, layers such as BatchNormalization create mean and variance\n        update ops that need to be run at training time. In Tensorflow 1.x,\n        these are thrown into an UPDATE_OPS collection. As Tensorflow 2.x\n        doesn't have collections, update_ops need to be passed explicitly here.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses. These losses are\n        usually expressed as a batch average, so for best results, in each head,\n        users need to use the default `loss_reduction=SUM_OVER_BATCH_SIZE` to\n        avoid scaling errors.  
Compared to the regularization losses for each\n        head, this loss is to regularize the merged loss of all heads in multi\n        head, and will be added to the overall training loss of multi head.\n\n    Returns:\n      A `model_fn.EstimatorSpec` instance.\n\n    Raises:\n      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN\n      mode, or if both are set.\n      If `mode` is not in Estimator's `ModeKeys`.\n    \"\"\"\n    with ops.name_scope(self.name, 'multi_head'):\n      logits_dict = self._check_logits_and_labels(logits, labels)\n      # Get all estimator spec.\n      all_estimator_spec = []\n      for head in self._heads:\n        all_estimator_spec.append(\n            head.create_estimator_spec(\n                features=features,\n                mode=mode,\n                logits=logits_dict[head.name],\n                labels=labels[head.name] if labels else None,\n                train_op_fn=_no_op_train_fn))\n      # Predict.\n      predictions = self.predictions(logits)\n      if mode == ModeKeys.PREDICT:\n        export_outputs = self._merge_predict_export_outputs(all_estimator_spec)\n        return model_fn.EstimatorSpec(\n            mode=ModeKeys.PREDICT,\n            predictions=predictions,\n            export_outputs=export_outputs)\n      loss = self.loss(labels, logits, features, mode, regularization_losses)\n      # Eval.\n      if mode == ModeKeys.EVAL:\n        eval_metrics = self.metrics(regularization_losses=regularization_losses)\n        updated_metrics = self.update_metrics(\n            eval_metrics,\n            features,\n            logits,\n            labels,\n            regularization_losses=regularization_losses)\n        return model_fn.EstimatorSpec(\n            mode=ModeKeys.EVAL,\n            predictions=predictions,\n            loss=loss,\n            eval_metric_ops=updated_metrics)\n      # Train.\n      if mode == ModeKeys.TRAIN:\n        train_op = 
base_head.create_estimator_spec_train_op(\n            head_name=self.name,\n            optimizer=optimizer,\n            train_op_fn=train_op_fn,\n            update_ops=update_ops,\n            trainable_variables=trainable_variables,\n            regularized_training_loss=loss,\n            loss_reduction=self.loss_reduction)\n        # Create summary.\n        base_head.create_estimator_spec_summary(loss, regularization_losses)\n        # eval_metrics.\n        eval_metrics = {}\n        for spec in all_estimator_spec:\n          eval_metrics.update(spec.eval_metric_ops or {})\n        # predictions can be used to access the logits in `TRAIN` mode\n        return model_fn.EstimatorSpec(\n            mode=ModeKeys.TRAIN,\n            loss=loss,\n            train_op=train_op,\n            predictions=predictions,\n            eval_metric_ops=eval_metrics)\n      raise ValueError('mode={} unrecognized'.format(mode))\n\n  def _merge_predict_export_outputs(self, all_estimator_spec):\n    \"\"\"Merges list of `EstimatorSpec` export_outputs for PREDICT.\n\n    For each individual head, its DEFAULT_SERVING_KEY and PREDICT_SERVING_KEY\n    are extracted and merged for `export_outputs` in PREDICT mode of\n    `EstimatorSpec`. 
By default, the first head is served.\n\n    Args:\n      all_estimator_spec: list of `EstimatorSpec` for the individual heads.\n\n    Returns:\n      A dict of merged export_outputs from all heads for PREDICT.\n    \"\"\"\n    # The first head is used for serving by default.\n    export_outputs = {\n        base_head.DEFAULT_SERVING_KEY:\n            _default_export_output(all_estimator_spec[0].export_outputs,\n                                   self._heads[0].name),\n    }\n    merged_predict_outputs = {}\n    for head, spec in zip(self._heads, all_estimator_spec):\n      for k, v in six.iteritems(spec.export_outputs):\n        # Collect default serving key for export_outputs\n        key = (\n            head.name if k == base_head.DEFAULT_SERVING_KEY else '{}/{}'.format(\n                head.name, k))\n        export_outputs[key] = v\n        # Collect predict serving key for merged_predict_outputs\n        if (k == base_head.PREDICT_SERVING_KEY and\n            isinstance(v, export_output.PredictOutput)):\n          for kp, vp in six.iteritems(v.outputs):\n            merged_predict_outputs['{}/{}'.format(head.name, kp)] = vp\n    export_outputs[base_head.PREDICT_SERVING_KEY] = (\n        export_output.PredictOutput(merged_predict_outputs))\n    return export_outputs\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/multi_head_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for multi_head.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.head import head_utils as test_lib\nfrom tensorflow_estimator.python.estimator.head import multi_head as multi_head_lib\nfrom tensorflow_estimator.python.estimator.head import multi_label_head\nfrom tensorflow_estimator.python.estimator.head import regression_head\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass MultiHeadTest(tf.test.TestCase):\n\n  def test_no_heads(self):\n    with self.assertRaisesRegexp(ValueError,\n                                 r'Must specify heads\\. 
Given: \\[\\]'):\n      multi_head_lib.MultiHead(heads=[])\n\n  def test_head_name_missing(self):\n    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')\n    head2 = multi_label_head.MultiLabelHead(n_classes=3)\n    with self.assertRaisesRegexp(ValueError,\n                                 r'All given heads must have name specified\\.'):\n      multi_head_lib.MultiHead([head1, head2])\n\n  def test_head_weights_wrong_size(self):\n    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')\n    head2 = multi_label_head.MultiLabelHead(n_classes=3, name='head2')\n    with self.assertRaisesRegexp(\n        ValueError, r'heads and head_weights must have the same size\\. '\n        r'Given len\\(heads\\): 2. Given len\\(head_weights\\): 1\\.'):\n      multi_head_lib.MultiHead([head1, head2], head_weights=[1.])\n\n  def test_name(self):\n    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')\n    head2 = multi_label_head.MultiLabelHead(n_classes=3, name='head2')\n    multi_head = multi_head_lib.MultiHead([head1, head2])\n    self.assertEqual('head1_head2', multi_head.name)\n\n  def test_predict_two_heads_logits_dict(self):\n    \"\"\"Tests predict with logits as dict.\"\"\"\n    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')\n    head2 = multi_label_head.MultiLabelHead(n_classes=3, name='head2')\n    multi_head = multi_head_lib.MultiHead([head1, head2])\n\n    logits = {\n        'head1': np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32),\n        'head2': np.array([[2., -2., 2.], [-3., 2., -2.]], dtype=np.float32)\n    }\n    expected_probabilities = {\n        'head1': tf.math.sigmoid(logits['head1']),\n        'head2': tf.math.sigmoid(logits['head2']),\n    }\n    pred_keys = prediction_keys.PredictionKeys\n\n    predictions = multi_head.predictions(logits)\n    self.assertAllClose(logits['head1'],\n                        self.evaluate(predictions[('head1', pred_keys.LOGITS)]))\n    
self.assertAllClose(logits['head2'],\n                        self.evaluate(predictions[('head2', pred_keys.LOGITS)]))\n    self.assertAllClose(\n        expected_probabilities['head1'],\n        self.evaluate(predictions[('head1', pred_keys.PROBABILITIES)]))\n    self.assertAllClose(\n        expected_probabilities['head2'],\n        self.evaluate(predictions[('head2', pred_keys.PROBABILITIES)]))\n    if tf.executing_eagerly():\n      return\n\n    spec = multi_head.create_estimator_spec(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits)\n    self.assertItemsEqual((test_lib._DEFAULT_SERVING_KEY, 'predict', 'head1',\n                           'head1/classification', 'head1/predict', 'head2',\n                           'head2/classification', 'head2/predict'),\n                          spec.export_outputs.keys())\n    # Assert predictions and export_outputs.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      predictions = sess.run(spec.predictions)\n      self.assertAllClose(logits['head1'],\n                          predictions[('head1', pred_keys.LOGITS)])\n      self.assertAllClose(logits['head2'],\n                          predictions[('head2', pred_keys.LOGITS)])\n      self.assertAllClose(expected_probabilities['head1'],\n                          predictions[('head1', pred_keys.PROBABILITIES)])\n      self.assertAllClose(expected_probabilities['head2'],\n                          predictions[('head2', pred_keys.PROBABILITIES)])\n\n      self.assertAllClose(\n          expected_probabilities['head1'],\n          sess.run(spec.export_outputs[test_lib._DEFAULT_SERVING_KEY].scores))\n      self.assertAllClose(expected_probabilities['head1'],\n                          sess.run(spec.export_outputs['head1'].scores))\n      self.assertAllClose(expected_probabilities['head2'],\n              
            sess.run(spec.export_outputs['head2'].scores))\n      self.assertAllClose(\n          expected_probabilities['head1'],\n          sess.run(\n              spec.export_outputs['predict'].outputs['head1/probabilities']))\n      self.assertAllClose(\n          expected_probabilities['head2'],\n          sess.run(\n              spec.export_outputs['predict'].outputs['head2/probabilities']))\n      self.assertAllClose(\n          expected_probabilities['head1'],\n          sess.run(\n              spec.export_outputs['head1/predict'].outputs['probabilities']))\n      self.assertAllClose(\n          expected_probabilities['head2'],\n          sess.run(\n              spec.export_outputs['head2/predict'].outputs['probabilities']))\n\n  def test_predict_two_heads_logits_tensor(self):\n    \"\"\"Tests predict with logits as Tensor.\"\"\"\n    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')\n    head2 = multi_label_head.MultiLabelHead(n_classes=3, name='head2')\n    multi_head = multi_head_lib.MultiHead([head1, head2])\n\n    logits = np.array([[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]],\n                      dtype=np.float32)\n    expected_logits1 = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)\n    expected_logits2 = np.array([[2., -2., 2.], [-3., 2., -2.]],\n                                dtype=np.float32)\n    expected_probabilities = {\n        'head1': tf.math.sigmoid(expected_logits1),\n        'head2': tf.math.sigmoid(expected_logits2),\n    }\n    pred_keys = prediction_keys.PredictionKeys\n\n    predictions = multi_head.predictions(logits)\n    self.assertAllClose(expected_logits1,\n                        self.evaluate(predictions[('head1', pred_keys.LOGITS)]))\n    self.assertAllClose(expected_logits2,\n                        self.evaluate(predictions[('head2', pred_keys.LOGITS)]))\n    self.assertAllClose(\n        expected_probabilities['head1'],\n        self.evaluate(predictions[('head1', 
pred_keys.PROBABILITIES)]))\n    self.assertAllClose(\n        expected_probabilities['head2'],\n        self.evaluate(predictions[('head2', pred_keys.PROBABILITIES)]))\n    if tf.executing_eagerly():\n      return\n\n    spec = multi_head.create_estimator_spec(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits)\n    self.assertItemsEqual((test_lib._DEFAULT_SERVING_KEY, 'predict', 'head1',\n                           'head1/classification', 'head1/predict', 'head2',\n                           'head2/classification', 'head2/predict'),\n                          spec.export_outputs.keys())\n    # Assert predictions and export_outputs.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      predictions = sess.run(spec.predictions)\n      self.assertAllClose(expected_logits1,\n                          predictions[('head1', pred_keys.LOGITS)])\n      self.assertAllClose(expected_logits2,\n                          predictions[('head2', pred_keys.LOGITS)])\n      self.assertAllClose(expected_probabilities['head1'],\n                          predictions[('head1', pred_keys.PROBABILITIES)])\n      self.assertAllClose(expected_probabilities['head2'],\n                          predictions[('head2', pred_keys.PROBABILITIES)])\n\n      self.assertAllClose(\n          expected_probabilities['head1'],\n          sess.run(spec.export_outputs[test_lib._DEFAULT_SERVING_KEY].scores))\n      self.assertAllClose(expected_probabilities['head1'],\n                          sess.run(spec.export_outputs['head1'].scores))\n      self.assertAllClose(expected_probabilities['head2'],\n                          sess.run(spec.export_outputs['head2'].scores))\n\n  def test_predict_two_heads_logits_tensor_multi_dim(self):\n    \"\"\"Tests predict with multi-dimensional logits of shape [2, 2, 5].\"\"\"\n    head1 = 
regression_head.RegressionHead(label_dimension=2, name='head1')\n    head2 = regression_head.RegressionHead(label_dimension=3, name='head2')\n    multi_head = multi_head_lib.MultiHead([head1, head2])\n\n    logits = np.array([[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]],\n                       [[-1.5, 1., -3., 2., -2.], [-1.5, 1., -3., 2., -2.]]],\n                      dtype=np.float32)\n    expected_logits1 = np.array(\n        [[[-1., 1.], [-1., 1.]], [[-1.5, 1.], [-1.5, 1.]]], dtype=np.float32)\n    expected_logits2 = np.array(\n        [[[2., -2., 2.], [2., -2., 2.]], [[-3., 2., -2.], [-3., 2., -2.]]],\n        dtype=np.float32)\n    pred_keys = prediction_keys.PredictionKeys\n\n    predictions = multi_head.predictions(logits)\n    self.assertAllClose(\n        expected_logits1,\n        self.evaluate(predictions[('head1', pred_keys.PREDICTIONS)]))\n    self.assertAllClose(\n        expected_logits2,\n        self.evaluate(predictions[('head2', pred_keys.PREDICTIONS)]))\n    if tf.executing_eagerly():\n      return\n\n    spec = multi_head.create_estimator_spec(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits)\n    self.assertItemsEqual(\n        (test_lib._DEFAULT_SERVING_KEY, 'predict', 'head1', 'head1/regression',\n         'head1/predict', 'head2', 'head2/regression', 'head2/predict'),\n        spec.export_outputs.keys())\n    # Assert predictions and export_outputs.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      predictions = sess.run(spec.predictions)\n      self.assertAllClose(expected_logits1,\n                          predictions[('head1', pred_keys.PREDICTIONS)])\n      self.assertAllClose(expected_logits2,\n                          predictions[('head2', pred_keys.PREDICTIONS)])\n\n      self.assertAllClose(\n          expected_logits1,\n          
sess.run(spec.export_outputs[test_lib._DEFAULT_SERVING_KEY].value))\n      self.assertAllClose(expected_logits1,\n                          sess.run(spec.export_outputs['head1'].value))\n      self.assertAllClose(expected_logits2,\n                          sess.run(spec.export_outputs['head2'].value))\n\n  def test_eval_two_heads_with_weights(self):\n    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')\n    head2 = multi_label_head.MultiLabelHead(n_classes=3, name='head2')\n    multi_head = multi_head_lib.MultiHead([head1, head2], head_weights=[1., 2.])\n\n    logits = {\n        'head1':\n            np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),\n        'head2':\n            np.array([[20., -20., 20.], [-30., 20., -20.]], dtype=np.float32),\n    }\n    labels = {\n        'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),\n        'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),\n    }\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # For large logits, sigmoid cross entropy loss is approximated as:\n    # loss = labels * (logits < 0) * (-logits) +\n    #        (1 - labels) * (logits > 0) * logits =>\n    # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]]\n    # loss = ((10 + 10) / 2 + (15 + 0) / 2) / 2 = 8.75\n    # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]]\n    # loss = ((20 + 20 + 20) / 3 + (30 + 0 + 0) / 3) / 2 = 15\n    expected_loss_head1 = 8.75\n    expected_loss_head2 = 15.\n    expected_loss = 1. * expected_loss_head1 + 2. 
* expected_loss_head2\n    tol = 1e-3\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS + '/head1': expected_loss_head1,\n        keys.LOSS + '/head2': expected_loss_head2,\n        # Average loss over examples.\n        keys.LOSS_MEAN + '/head1': expected_loss_head1,\n        keys.LOSS_MEAN + '/head2': expected_loss_head2,\n        # auc and auc_pr cannot be reliably calculated for only 4-6 samples, but\n        # this assert tests that the algorithm remains consistent.\n        keys.AUC + '/head1': 0.1667,\n        keys.AUC + '/head2': 0.3333,\n        keys.AUC_PR + '/head1': 0.60228,\n        keys.AUC_PR + '/head2': 0.40152,\n    }\n\n    if tf.executing_eagerly():\n      loss = multi_head.loss(\n          labels, logits, features=features, mode=ModeKeys.EVAL)\n      self.assertIsNotNone(loss)\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n\n      eval_metrics = multi_head.metrics()\n      updated_metrics = multi_head.update_metrics(eval_metrics, features,\n                                                  logits, labels)\n      self.assertItemsEqual(expected_metrics.keys(), updated_metrics.keys())\n      self.assertAllClose(\n          expected_metrics,\n          {k: updated_metrics[k].result() for k in updated_metrics},\n          rtol=tol,\n          atol=tol)\n      return\n\n    spec = multi_head.create_estimator_spec(\n        features=features, mode=ModeKeys.EVAL, logits=logits, labels=labels)\n    # Assert spec contains expected tensors.\n    self.assertIsNotNone(spec.loss)\n    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())\n    self.assertIsNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    test_lib._assert_no_hooks(self, spec)\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: 
spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, _ = sess.run((spec.loss, update_ops))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      # Check results of value ops (in `metrics`).\n      self.assertAllClose(\n          expected_metrics, {k: value_ops[k].eval() for k in value_ops},\n          rtol=tol,\n          atol=tol)\n\n  def test_train_loss_one_head(self):\n    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')\n    multi_head = multi_head_lib.MultiHead([head1])\n\n    logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)}\n    labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)}\n    loss = multi_head.loss(\n        labels=labels,\n        logits=logits,\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.TRAIN)\n    tol = 1e-3\n    # Unreduced loss of the head is [[(10 + 10) / 2], (15 + 0) / 2]\n    # (averaged over classes, averaged over examples).\n    # loss = sum(unreduced_loss) / 2 = sum([10, 7.5]) / 2 = 8.75\n    self.assertAllClose(8.75, self.evaluate(loss), rtol=tol, atol=tol)\n\n  def test_train_loss_two_heads_with_weights(self):\n    # Use different example weighting for each head weighting.\n    weights1 = np.array([[1.], [2.]], dtype=np.float32)\n    weights2 = np.array([[2.], [3.]])\n    head1 = multi_label_head.MultiLabelHead(\n        n_classes=2, name='head1', weight_column='weights1')\n    head2 = multi_label_head.MultiLabelHead(\n        n_classes=3, name='head2', weight_column='weights2')\n    multi_head = multi_head_lib.MultiHead([head1, head2], head_weights=[1., 2.])\n\n    logits = {\n        'head1':\n            np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),\n        'head2':\n            np.array([[20., -20., 20.], [-30., 20., -20.]], dtype=np.float32),\n    }\n    labels = {\n        'head1': np.array([[1, 0], [1, 
1]], dtype=np.int64),\n        'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),\n    }\n    training_loss = multi_head.loss(\n        logits=logits,\n        labels=labels,\n        features={\n            'x': np.array(((42,),), dtype=np.int32),\n            'weights1': weights1,\n            'weights2': weights2\n        },\n        mode=ModeKeys.TRAIN)\n    tol = 1e-3\n    # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]]\n    # = [10, 7.5]\n    # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5\n    # head-weighted unreduced_loss = 1 * [10, 7.5]\n    # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]]\n    # = [20, 10]\n    # training_loss = (2 * 20 + 3 * 10) / 2 = 35\n    # head-weighted unreduced_loss = 2 * [20, 10]\n    # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5\n    self.assertAllClose(82.5, self.evaluate(training_loss), rtol=tol, atol=tol)\n\n  def test_train_loss_logits_tensor(self):\n    \"\"\"Tests loss with logits Tensor.\"\"\"\n    weights1 = np.array([[1.], [2.]], dtype=np.float32)\n    weights2 = np.array([[2.], [3.]])\n    head1 = multi_label_head.MultiLabelHead(\n        n_classes=2, name='head1', weight_column='weights1')\n    head2 = multi_label_head.MultiLabelHead(\n        n_classes=3, name='head2', weight_column='weights2')\n    multi_head = multi_head_lib.MultiHead([head1, head2], head_weights=[1., 2.])\n\n    logits = np.array(\n        [[-10., 10., 20., -20., 20.], [-15., 10., -30., 20., -20.]],\n        dtype=np.float32)\n    labels = {\n        'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),\n        'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),\n    }\n    training_loss = multi_head.loss(\n        logits=logits,\n        labels=labels,\n        features={\n            'x': np.array(((42,),), dtype=np.int32),\n            'weights1': weights1,\n            'weights2': weights2\n        },\n        mode=ModeKeys.TRAIN)\n    tol = 1e-3\n    # loss of the first head is 
[[(10 + 10) / 2], [(15 + 0) / 2]]\n    # = [10, 7.5]\n    # training_loss = (1 * 10 + 2 * 7.5) / 2 = 12.5\n    # head-weighted unreduced_loss = 1 * [10, 7.5]\n    # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]]\n    # = [20, 10]\n    # training_loss = (2 * 20 + 3 * 10) / 2 = 35\n    # head-weighted unreduced_loss = 2 * [20, 10]\n    # head-weighted training_loss = 1 * 12.5 + 2 * 35 = 82.5\n    self.assertAllClose(82.5, self.evaluate(training_loss), rtol=tol, atol=tol)\n\n  def test_train_loss_logits_tensor_wrong_shape(self):\n    \"\"\"Tests loss with a logits Tensor of the wrong shape.\"\"\"\n    weights1 = np.array([[1.], [2.]], dtype=np.float32)\n    weights2 = np.array([[2.], [3.]])\n    head1 = multi_label_head.MultiLabelHead(\n        n_classes=2, name='head1', weight_column='weights1')\n    head2 = multi_label_head.MultiLabelHead(\n        n_classes=3, name='head2', weight_column='weights2')\n    multi_head = multi_head_lib.MultiHead([head1, head2], head_weights=[1., 2.])\n\n    # logits tensor is 2x6 instead of 2x5\n    logits = np.array(\n        [[-10., 10., 20., -20., 20., 70.], [-15., 10., -30., 20., -20., 80.]],\n        dtype=np.float32)\n    labels = {\n        'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),\n        'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),\n    }\n    with self.assertRaisesRegexp(ValueError, r'Could not split logits'):\n      multi_head.loss(\n          features={\n              'x': np.array(((42,),), dtype=np.int32),\n              'weights1': weights1,\n              'weights2': weights2\n          },\n          mode=ModeKeys.TRAIN,\n          logits=logits,\n          labels=labels)\n\n  def test_train_loss_logits_tensor_multi_dim(self):\n    \"\"\"Tests loss with multi-dimensional logits of shape [2, 2, 5].\"\"\"\n    head1 = regression_head.RegressionHead(label_dimension=2, name='head1')\n    head2 = regression_head.RegressionHead(label_dimension=3, name='head2')\n    multi_head = 
multi_head_lib.MultiHead([head1, head2])\n\n    logits = np.array([[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]],\n                       [[-1.5, 1.5, -2., 2., -2.], [-1.5, 1.5, -2., 2., -2.]]],\n                      dtype=np.float32)\n    labels = {\n        'head1':\n            np.array([[[1., 0.], [1., 0.]], [[1.5, 1.5], [1.5, 1.5]]],\n                     dtype=np.float32),\n        'head2':\n            np.array(\n                [[[0., 1., 0.], [0., 1., 0.]], [[2., 2., 0.], [2., 2., 0.]]],\n                dtype=np.float32),\n    }\n    # Loss for the first head:\n    # loss1 = ((1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 +\n    #          (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2) / 8\n    #       = 3.5\n    # Loss for the second head:\n    # loss2 = ((0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 +\n    #          (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2) / 12\n    #       = 6.167\n    expected_training_loss = 3.5 + 6.167\n\n    training_loss = multi_head.loss(\n        logits=logits, labels=labels, features={}, mode=ModeKeys.TRAIN)\n    tol = 1e-3\n    self.assertAllClose(\n        expected_training_loss,\n        self.evaluate(training_loss),\n        rtol=tol,\n        atol=tol)\n\n  def test_train_loss_logits_tensor_multi_dim_wrong_shape(self):\n    \"\"\"Tests loss with a multi-dimensional logits tensor of the wrong shape.\"\"\"\n    head1 = regression_head.RegressionHead(label_dimension=2, name='head1')\n    head2 = regression_head.RegressionHead(label_dimension=3, name='head2')\n    multi_head = multi_head_lib.MultiHead([head1, head2])\n\n    # logits tensor is 2x2x4 instead of 2x2x5\n    logits = np.array([[[-1., 1., 2., -2.], [-1., 1., 2., -2.]],\n                       [[-1.5, 1.5, -2., 2.], [-1.5, 1.5, -2., 2.]]],\n                      dtype=np.float32)\n    labels = {\n        'head1':\n            np.array([[[1., 0.], [1., 0.]], [[1.5, 1.5], [1.5, 1.5]]],\n                     dtype=np.float32),\n        
'head2':\n            np.array(\n                [[[0., 1., 0.], [0., 1., 0.]], [[2., 2., 0.], [2., 2., 0.]]],\n                dtype=np.float32),\n    }\n    with self.assertRaisesRegexp(ValueError, r'Could not split logits'):\n      multi_head.loss(\n          features={}, mode=ModeKeys.TRAIN, logits=logits, labels=labels)\n\n  def test_train_one_head(self):\n    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')\n    multi_head = multi_head_lib.MultiHead([head1])\n\n    logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)}\n    expected_probabilities = {\n        'head1': tf.math.sigmoid(logits['head1']),\n    }\n    labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)}\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # For large logits, sigmoid cross entropy loss is approximated as:\n    # loss = labels * (logits < 0) * (-logits) +\n    #        (1 - labels) * (logits > 0) * logits =>\n    # expected_unweighted_loss = [[10., 10.], [15., 0.]]\n    # loss = ((10 + 10) / 2 + (15 + 0) / 2) / 2 = 8.75\n    expected_loss = 8.75\n    tol = 1e-3\n    loss = multi_head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n    self.assertAllClose(expected_loss, self.evaluate(loss), rtol=tol, atol=tol)\n    if tf.executing_eagerly():\n      return\n\n    expected_train_result = 'my_train_op'\n\n    def _train_op_fn(loss):\n      return tf.strings.join([\n          tf.constant(expected_train_result),\n          tf.strings.as_string(loss, precision=3)\n      ])\n\n    spec = multi_head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn)\n    self.assertIsNotNone(spec.loss)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    test_lib._assert_no_hooks(self, spec)\n    # Assert predictions, 
loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str, predictions = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op,\n           spec.predictions))\n      self.assertAllClose(\n          logits['head1'],\n          predictions[('head1', prediction_keys.PredictionKeys.LOGITS)])\n      self.assertAllClose(\n          expected_probabilities['head1'],\n          predictions[('head1', prediction_keys.PredictionKeys.PROBABILITIES)])\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),\n          train_result)\n      test_lib._assert_simple_summaries(\n          self, {\n              metric_keys.MetricKeys.LOSS: expected_loss,\n              metric_keys.MetricKeys.LOSS + '/head1': expected_loss,\n          }, summary_str, tol)\n\n  def test_train_one_head_with_optimizer(self):\n    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')\n    multi_head = multi_head_lib.MultiHead([head1])\n\n    logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)}\n    labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)}\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # For large logits, sigmoid cross entropy loss is approximated as:\n    # loss = labels * (logits < 0) * (-logits) +\n    #        (1 - labels) * (logits > 0) * logits =>\n    # expected_unweighted_loss = [[10., 10.], [15., 0.]]\n    # loss = ((10 + 10) / 2 + (15 + 0) / 2) / 2 = 8.75\n    expected_loss = 8.75\n    tol = 1e-3\n    loss = multi_head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n    self.assertAllClose(expected_loss, self.evaluate(loss), rtol=tol, atol=tol)\n    if tf.executing_eagerly():\n      return\n\n 
   expected_train_result = 'my_train_op'\n\n    class _Optimizer(tf_keras.optimizers.Optimizer):\n\n      def get_updates(self, loss, params):\n        del params\n        return [\n            tf.strings.join([\n                tf.constant(expected_train_result),\n                tf.strings.as_string(loss, precision=3)\n            ])\n        ]\n\n      def get_config(self):\n        config = super(_Optimizer, self).get_config()\n        return config\n\n    spec = multi_head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        optimizer=_Optimizer('my_optimizer'),\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      loss, train_result = sess.run((spec.loss, spec.train_op))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),\n          train_result)\n\n  def test_train_two_heads_with_weights(self):\n    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')\n    head2 = multi_label_head.MultiLabelHead(n_classes=3, name='head2')\n    multi_head = multi_head_lib.MultiHead([head1, head2], head_weights=[1., 2.])\n\n    logits = {\n        'head1':\n            np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),\n        'head2':\n            np.array([[20., -20., 20.], [-30., 20., -20.]], dtype=np.float32),\n    }\n    expected_probabilities = {\n        'head1': tf.math.sigmoid(logits['head1']),\n        'head2': tf.math.sigmoid(logits['head2']),\n    }\n    labels = {\n        'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),\n        'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),\n    }\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # For large logits, sigmoid cross entropy loss is 
approximated as:\n    # loss = labels * (logits < 0) * (-logits) +\n    #        (1 - labels) * (logits > 0) * logits =>\n    # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]]\n    # loss = ((10 + 10) / 2 + (15 + 0) / 2) / 2 = 8.75\n    # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]]\n    # loss = ((20 + 20 + 20) / 3 + (30 + 0 + 0) / 3) / 2 = 15\n    # Average over classes, weighted sum over batch and heads.\n    expected_loss_head1 = 8.75\n    expected_loss_head2 = 15.0\n    expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2\n    tol = 1e-3\n    loss = multi_head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n    self.assertAllClose(expected_loss, self.evaluate(loss), rtol=tol, atol=tol)\n    if tf.executing_eagerly():\n      return\n\n    expected_train_result = 'my_train_op'\n\n    def _train_op_fn(loss):\n      return tf.strings.join([\n          tf.constant(expected_train_result),\n          tf.strings.as_string(loss, precision=3)\n      ])\n\n    spec = multi_head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn)\n    self.assertIsNotNone(spec.loss)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    test_lib._assert_no_hooks(self, spec)\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str, predictions = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op,\n           spec.predictions))\n      self.assertAllClose(\n          logits['head1'],\n          predictions[('head1', prediction_keys.PredictionKeys.LOGITS)])\n      self.assertAllClose(\n          
expected_probabilities['head1'],\n          predictions[('head1', prediction_keys.PredictionKeys.PROBABILITIES)])\n      self.assertAllClose(\n          logits['head2'],\n          predictions[('head2', prediction_keys.PredictionKeys.LOGITS)])\n      self.assertAllClose(\n          expected_probabilities['head2'],\n          predictions[('head2', prediction_keys.PredictionKeys.PROBABILITIES)])\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),\n          train_result)\n      test_lib._assert_simple_summaries(\n          self, {\n              metric_keys.MetricKeys.LOSS: expected_loss,\n              metric_keys.MetricKeys.LOSS + '/head1': expected_loss_head1,\n              metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2,\n          }, summary_str, tol)\n\n  def test_train_with_regularization_losses(self):\n    head1 = multi_label_head.MultiLabelHead(n_classes=2, name='head1')\n    head2 = multi_label_head.MultiLabelHead(n_classes=3, name='head2')\n    multi_head = multi_head_lib.MultiHead([head1, head2], head_weights=[1., 2.])\n\n    logits = {\n        'head1':\n            np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),\n        'head2':\n            np.array([[20., -20., 20.], [-30., 20., -20.]], dtype=np.float32),\n    }\n    expected_probabilities = {\n        'head1': tf.math.sigmoid(logits['head1']),\n        'head2': tf.math.sigmoid(logits['head2']),\n    }\n    labels = {\n        'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),\n        'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),\n    }\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    regularization_losses = [1.5, 0.5]\n\n    # For large logits, sigmoid cross entropy loss is approximated as:\n    # loss = labels * (logits < 0) * (-logits) +\n    #        (1 - labels) * (logits > 0) * logits =>\n    # head1: expected_unweighted_loss 
= [[10., 10.], [15., 0.]]\n    # loss1 = ((10 + 10) / 2 + (15 + 0) / 2) / 2 = 8.75\n    # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]]\n    # loss2 = ((20 + 20 + 20) / 3 + (30 + 0 + 0) / 3) / 2 = 15\n    # Average over classes, weighted sum over batch and heads.\n    # weights = [1., 2.]\n    # merged_training_loss = 1. * loss1 + 2. * loss2\n    # training_loss = merged_training_loss + regularization_loss\n    #               = 1. * loss1 + 2. * loss2 + sum([1.5, 0.5])\n    expected_loss_head1 = 8.75\n    expected_loss_head2 = 15.0\n    expected_regularization_loss = 2.\n    # training loss.\n    expected_loss = (1. * expected_loss_head1 + 2. * expected_loss_head2 +\n                     expected_regularization_loss)\n    tol = 1e-3\n    loss = multi_head.loss(\n        logits=logits,\n        labels=labels,\n        features=features,\n        mode=ModeKeys.TRAIN,\n        regularization_losses=regularization_losses)\n    self.assertAllClose(expected_loss, self.evaluate(loss), rtol=tol, atol=tol)\n    if tf.executing_eagerly():\n      return\n\n    keys = metric_keys.MetricKeys\n    expected_train_result = 'my_train_op'\n\n    def _train_op_fn(loss):\n      return tf.strings.join([\n          tf.constant(expected_train_result),\n          tf.strings.as_string(loss, precision=3)\n      ])\n\n    spec = multi_head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        regularization_losses=regularization_losses)\n    self.assertIsNotNone(spec.loss)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    test_lib._assert_no_hooks(self, spec)\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      
self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str, predictions = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op,\n           spec.predictions))\n      self.assertAllClose(\n          logits['head1'],\n          predictions[('head1', prediction_keys.PredictionKeys.LOGITS)])\n      self.assertAllClose(\n          expected_probabilities['head1'],\n          predictions[('head1', prediction_keys.PredictionKeys.PROBABILITIES)])\n      self.assertAllClose(\n          logits['head2'],\n          predictions[('head2', prediction_keys.PredictionKeys.LOGITS)])\n      self.assertAllClose(\n          expected_probabilities['head2'],\n          predictions[('head2', prediction_keys.PredictionKeys.PROBABILITIES)])\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),\n          train_result)\n      test_lib._assert_simple_summaries(\n          self, {\n              keys.LOSS_REGULARIZATION: expected_regularization_loss,\n              keys.LOSS: expected_loss,\n              keys.LOSS + '/head1': expected_loss_head1,\n              keys.LOSS + '/head2': expected_loss_head2,\n          }, summary_str, tol)\n\n\n@test_util.deprecated_graph_mode_only\nclass MultiHeadForEstimator(tf.test.TestCase):\n  \"\"\"Tests for create_estimator_spec running in Graph mode only.\"\"\"\n\n  def test_loss_reduction_must_be_same(self):\n    \"\"\"Tests the loss reduction must be the same for different heads.\"\"\"\n    head1 = multi_label_head.MultiLabelHead(\n        n_classes=2,\n        name='head1',\n        loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE)\n    head2 = multi_label_head.MultiLabelHead(\n        n_classes=3, name='head2', loss_reduction=tf.losses.Reduction.AUTO)\n    multi_head = multi_head_lib.MultiHead([head1, head2])\n    logits = {\n        'head1':\n            np.array([[-10., 10.], [-15., 
10.]], dtype=np.float32),\n        'head2':\n            np.array([[20., -20., 20.], [-30., 20., -20.]], dtype=np.float32),\n    }\n    labels = {\n        'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),\n        'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),\n    }\n    with self.assertRaisesRegexp(ValueError, 'must be the same'):\n      multi_head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=logits,\n          labels=labels)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/multi_label_head.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Multi label head.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.ops import lookup_ops\nfrom tensorflow_estimator.python.estimator import model_fn\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.export import export_output\nfrom tensorflow_estimator.python.estimator.head import base_head\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v2\n\n\n@estimator_export('estimator.MultiLabelHead')\nclass MultiLabelHead(base_head.Head):\n  \"\"\"Creates a `Head` for multi-label classification.\n\n  Multi-label classification handles the case where each example may have zero\n  or more associated labels, from a discrete set. 
This is distinct from\n  `MultiClassHead` which has exactly one label per example.\n\n  Uses `sigmoid_cross_entropy` loss average over classes and weighted sum over\n  the batch. Namely, if the input logits have shape `[batch_size, n_classes]`,\n  the loss is the average over `n_classes` and the weighted sum over\n  `batch_size`.\n\n  The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many\n  applications, the shape is `[batch_size, n_classes]`.\n\n  Labels can be:\n\n  * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`\n  * An integer `SparseTensor` of class indices. The `dense_shape` must be\n    `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.\n  * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape`\n    must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary` or a\n    multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`.\n\n  If `weight_column` is specified, weights must be of shape\n  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.\n\n  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or\n  `(labels, logits, features)` as arguments and returns unreduced loss with\n  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with\n  shape `[D0, D1, ... DN, n_classes]`. 
Namely, the head applies\n  `label_vocabulary` to the input labels before passing them to `loss_fn`.\n\n  Usage:\n\n  >>> n_classes = 2\n  >>> head = tf.estimator.MultiLabelHead(n_classes)\n  >>> logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)\n  >>> labels = np.array([[1, 0], [1, 1]], dtype=np.int64)\n  >>> features = {'x': np.array([[41], [42]], dtype=np.int32)}\n  >>> # expected_loss = sum(_sigmoid_cross_entropy(labels, logits)) / batch_size\n  >>> #               = sum(1.31326169, 0.9514133) / 2 = 1.13\n  >>> loss = head.loss(labels, logits, features=features)\n  >>> print('{:.2f}'.format(loss.numpy()))\n  1.13\n  >>> eval_metrics = head.metrics()\n  >>> updated_metrics = head.update_metrics(\n  ...   eval_metrics, features, logits, labels)\n  >>> for k in sorted(updated_metrics):\n  ...  print('{} : {:.2f}'.format(k, updated_metrics[k].result().numpy()))\n  auc : 0.33\n  auc_precision_recall : 0.77\n  average_loss : 1.13\n  >>> preds = head.predictions(logits)\n  >>> print(preds['logits'])\n  tf.Tensor(\n    [[-1.   1. ]\n     [-1.5  1.5]], shape=(2, 2), dtype=float32)\n\n  Usage with a canned estimator:\n\n  ```python\n  my_head = tf.estimator.MultiLabelHead(n_classes=3)\n  my_estimator = tf.estimator.DNNEstimator(\n      head=my_head,\n      hidden_units=...,\n      feature_columns=...)\n  ```\n\n  It can also be used with a custom `model_fn`. 
Example:\n\n  ```python\n  def _my_model_fn(features, labels, mode):\n    my_head = tf.estimator.MultiLabelHead(n_classes=3)\n    logits = tf_keras.Model(...)(features)\n\n    return my_head.create_estimator_spec(\n        features=features,\n        mode=mode,\n        labels=labels,\n        optimizer=tf_keras.optimizers.Adagrad(lr=0.1),\n        logits=logits)\n\n  my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)\n  ```\n\n  Args:\n    n_classes: Number of classes, must be greater than 1 (for 1 class, use\n      `BinaryClassHead`).\n    weight_column: A string or a `NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It is used to down weight or boost examples during training. It\n      will be multiplied by the loss of the example.  Per-class weighting is not\n      supported.\n    thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision\n      and recall metrics are evaluated for each threshold value. The threshold\n      is applied to the predicted probabilities, i.e. above the threshold is\n      `true`, below is `false`.\n    label_vocabulary: A list of strings represents possible label values. If it\n      is not given, that means labels are already encoded as integer within [0,\n      n_classes) or multi-hot Tensor. If given, labels must be SparseTensor\n      `string` type and have any value in `label_vocabulary`. Also there will be\n      errors if vocabulary is not provided and labels are string.\n    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Decides how to\n      reduce training loss over batch. Defaults to `SUM_OVER_BATCH_SIZE`, namely\n      weighted sum of losses divided by batch size.\n    loss_fn: Optional loss function.\n    classes_for_class_based_metrics: List of integer class IDs or string class\n      names for which per-class metrics are evaluated. If integers, all must be\n      in the range `[0, n_classes - 1]`. 
If strings, all must be in\n      `label_vocabulary`.\n    name: Name of the head. If provided, summary and metrics keys will be\n      suffixed by `\"/\" + name`. Also used as `name_scope` when creating ops.\n  \"\"\"\n\n  def __init__(self,\n               n_classes,\n               weight_column=None,\n               thresholds=None,\n               label_vocabulary=None,\n               loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,\n               loss_fn=None,\n               classes_for_class_based_metrics=None,\n               name=None):\n    if n_classes is None or n_classes < 2:\n      raise ValueError('n_classes must be > 1 for multi-label classification. '\n                       'Given: {}'.format(n_classes))\n    thresholds = tuple(thresholds) if thresholds else tuple()\n    for threshold in thresholds:\n      if (threshold <= 0.0) or (threshold >= 1.0):\n        raise ValueError(\n            'thresholds must be in (0, 1) range. Given: {}'.format(threshold))\n    if label_vocabulary is not None:\n      if not isinstance(label_vocabulary, (list, tuple)):\n        raise ValueError('label_vocabulary must be a list or tuple. '\n                         'Given type: {}'.format(type(label_vocabulary)))\n      if len(label_vocabulary) != n_classes:\n        raise ValueError('Length of label_vocabulary must be n_classes ({}). 
'\n                         'Given: {}'.format(n_classes, len(label_vocabulary)))\n\n    if loss_fn:\n      base_head.validate_loss_fn_args(loss_fn)\n    base_head.validate_loss_reduction(loss_reduction)\n    if classes_for_class_based_metrics:\n      classes_for_class_based_metrics = tuple(classes_for_class_based_metrics)\n      if isinstance(classes_for_class_based_metrics[0], six.string_types):\n        if not label_vocabulary:\n          raise ValueError('label_vocabulary must be provided when '\n                           'classes_for_class_based_metrics are strings.')\n        class_ids = []\n        for class_string in classes_for_class_based_metrics:\n          class_ids.append(label_vocabulary.index(class_string))\n        classes_for_class_based_metrics = tuple(class_ids)\n      else:\n        for class_id in classes_for_class_based_metrics:\n          if (class_id < 0) or (class_id >= n_classes):\n            raise ValueError(\n                'All classes_for_class_based_metrics must be in range [0, {}]. 
'\n                'Given: {}'.format(n_classes - 1, class_id))\n    else:\n      classes_for_class_based_metrics = tuple()\n    self._n_classes = n_classes\n    self._weight_column = weight_column\n    self._thresholds = thresholds\n    self._label_vocabulary = label_vocabulary\n    self._loss_reduction = loss_reduction\n    self._loss_fn = loss_fn\n    self._classes_for_class_based_metrics = classes_for_class_based_metrics\n    self._name = name\n    # Metric keys.\n    keys = metric_keys.MetricKeys\n    self._loss_mean_key = self._summary_key(keys.LOSS_MEAN)\n    self._auc_key = self._summary_key(keys.AUC)\n    self._auc_pr_key = self._summary_key(keys.AUC_PR)\n    self._loss_regularization_key = self._summary_key(keys.LOSS_REGULARIZATION)\n    accuracy_keys = []\n    precision_keys = []\n    recall_keys = []\n    for threshold in self._thresholds:\n      accuracy_keys.append(\n          self._summary_key(keys.ACCURACY_AT_THRESHOLD % threshold))\n      precision_keys.append(\n          self._summary_key(keys.PRECISION_AT_THRESHOLD % threshold))\n      recall_keys.append(\n          self._summary_key(keys.RECALL_AT_THRESHOLD % threshold))\n    self._accuracy_keys = tuple(accuracy_keys)\n    self._precision_keys = tuple(precision_keys)\n    self._recall_keys = tuple(recall_keys)\n    prob_keys = []\n    auc_keys = []\n    auc_pr_keys = []\n    for class_id in self._classes_for_class_based_metrics:\n      if self._label_vocabulary is None:\n        prob_key = keys.PROBABILITY_MEAN_AT_CLASS % class_id\n        auc_key = keys.AUC_AT_CLASS % class_id\n        auc_pr_key = keys.AUC_PR_AT_CLASS % class_id\n      else:\n        prob_key = (\n            keys.PROBABILITY_MEAN_AT_NAME % self._label_vocabulary[class_id])\n        auc_key = keys.AUC_AT_NAME % self._label_vocabulary[class_id]\n        auc_pr_key = keys.AUC_PR_AT_NAME % self._label_vocabulary[class_id]\n      prob_keys.append(self._summary_key(prob_key))\n      auc_keys.append(self._summary_key(auc_key))\n     
 auc_pr_keys.append(self._summary_key(auc_pr_key))\n    self._prob_keys = tuple(prob_keys)\n    self._auc_keys = tuple(auc_keys)\n    self._auc_pr_keys = tuple(auc_pr_keys)\n\n  @property\n  def name(self):\n    \"\"\"See `base_head.Head` for details.\"\"\"\n    return self._name\n\n  @property\n  def logits_dimension(self):\n    \"\"\"See `base_head.Head` for details.\"\"\"\n    return self._n_classes\n\n  @property\n  def loss_reduction(self):\n    \"\"\"See `base_head.Head` for details.\"\"\"\n    return self._loss_reduction\n\n  # An attribute for lookup table. Note that for Graph execution, the lookup\n  # table is created on demand to make sure the lookup table is in the same\n  # graph as its input tensors for `train` and `eval` of Estimator (as Estimator\n  # re-creates graphs for `train`, `eval` and `predict`).\n  _cached_class_id_table = None\n\n  @property\n  def _class_id_table(self):\n    \"\"\"Creates a lookup table for class_id.\n\n    In eager execution, this lookup table will be lazily created on the first\n    call of `self._class_id_table`, and cached for later use; In graph\n    execution, it will be created on demand.\n\n    Returns:\n      A hash table for lookup.\n    \"\"\"\n    if self._cached_class_id_table is None or not tf.executing_eagerly():\n      self._cached_class_id_table = lookup_ops.index_table_from_tensor(\n          vocabulary_list=tuple(self._label_vocabulary), name='class_id_lookup')\n    return self._cached_class_id_table\n\n  def _processed_labels(self, logits, labels):\n    \"\"\"Converts labels to integer id space.\"\"\"\n    if labels is None:\n      raise ValueError(base_head._LABEL_NONE_ERR_MSG)  # pylint:disable=protected-access\n    if isinstance(labels, tf.sparse.SparseTensor):\n      label_values = labels.values\n      if labels.dtype == tf.dtypes.string:\n        label_ids_values = self._class_id_table.lookup(label_values)\n        label_ids = tf.sparse.SparseTensor(\n            indices=labels.indices,\n          
  values=label_ids_values,\n            dense_shape=labels.dense_shape)\n        processed_labels = tf.sparse.to_indicator(label_ids, self._n_classes)\n      else:\n        if not label_values.dtype.is_integer:\n          raise ValueError(\n              'Labels dtype should be integer. Instead got {}.'.format(\n                  label_values.dtype))\n        err_msg = (r'labels must be an integer SparseTensor with values in '\n                   r'[0, {})'.format(self._n_classes))\n        label_values = base_head.check_label_range(\n            labels.values, self._n_classes, message=err_msg)\n        if tf.executing_eagerly():\n          processed_labels = tf.sparse.to_indicator(labels, self._n_classes)\n        else:\n          with tf.control_dependencies([label_values]):\n            processed_labels = tf.sparse.to_indicator(labels, self._n_classes)\n      processed_labels = tf.cast(processed_labels, dtype=tf.dtypes.int64)\n    else:\n      err_msg = (\n          r'labels must be an integer indicator Tensor with values in [0, 1]')\n      processed_labels = base_head.check_label_range(labels, 2, message=err_msg)\n\n    return base_head.check_dense_labels_match_logits_and_reshape(\n        labels=processed_labels,\n        logits=logits,\n        expected_labels_dimension=self.logits_dimension)\n\n  def _unweighted_loss_and_weights(self, logits, processed_labels, features):\n    \"\"\"Computes loss spec.\"\"\"\n    if self._loss_fn:\n      unweighted_loss = base_head.call_loss_fn(\n          loss_fn=self._loss_fn,\n          labels=processed_labels,\n          logits=logits,\n          features=features,\n          expected_loss_dim=1)\n    else:\n      unweighted_loss = tf.compat.v1.losses.sigmoid_cross_entropy(\n          multi_class_labels=processed_labels,\n          logits=logits,\n          reduction=tf.compat.v1.losses.Reduction.NONE)\n      # Averages loss over classes.\n      unweighted_loss = tf.math.reduce_mean(\n          unweighted_loss, axis=-1, 
keepdims=True)\n    weights = base_head.get_weights_and_check_match_logits(\n        features=features, weight_column=self._weight_column, logits=logits)\n    return unweighted_loss, weights\n\n  def loss(self,\n           labels,\n           logits,\n           features=None,\n           mode=None,\n           regularization_losses=None):\n    \"\"\"Returns regularized training loss. See `base_head.Head` for details.\"\"\"\n    del mode  # Unused for this head.\n    with ops.name_scope(\n        'losses', values=(logits, labels, regularization_losses, features)):\n      logits = base_head.check_logits_final_dim(logits, self.logits_dimension)\n      processed_labels = self._processed_labels(logits, labels)\n      unweighted_loss, weights = self._unweighted_loss_and_weights(\n          logits, processed_labels, features)\n      training_loss = tf_keras_v2.__internal__.losses.compute_weighted_loss(\n          unweighted_loss,\n          sample_weight=weights,\n          reduction=self._loss_reduction)\n      regularization_loss = tf.math.add_n(\n          regularization_losses) if regularization_losses is not None else None\n      regularized_training_loss = (\n          training_loss + regularization_loss\n          if regularization_loss is not None else training_loss)\n    return regularized_training_loss\n\n  def predictions(self, logits, keys=None):\n    \"\"\"Return predictions based on keys.\n\n    See `base_head.Head` for details.\n\n    Args:\n      logits: logits `Tensor` with shape `[D0, D1, ... DN, logits_dimension]`.\n        For many applications, the shape is `[batch_size, logits_dimension]`.\n      keys: a list of prediction keys. 
Key can be either the class variable\n        of prediction_keys.PredictionKeys or its string value, such as:\n          prediction_keys.PredictionKeys.LOGITS or 'logits'.\n\n    Returns:\n      A dict of predictions.\n    \"\"\"\n    pred_keys = prediction_keys.PredictionKeys\n    valid_keys = [pred_keys.LOGITS, pred_keys.PROBABILITIES, pred_keys.CLASSES]\n    if keys:\n      base_head.check_prediction_keys(keys, valid_keys)\n    else:\n      keys = valid_keys\n    logits = base_head.check_logits_final_dim(logits, self.logits_dimension)\n    predictions = {}\n    with ops.name_scope('predictions', values=(logits,)):\n      if pred_keys.LOGITS in keys:\n        predictions[pred_keys.LOGITS] = logits\n      if pred_keys.PROBABILITIES in keys:\n        probabilities = tf.math.sigmoid(logits, name=pred_keys.PROBABILITIES)\n        predictions[pred_keys.PROBABILITIES] = probabilities\n      if pred_keys.CLASSES in keys:\n        predictions[pred_keys.CLASSES] = base_head.all_classes(\n            logits, self._n_classes, self._label_vocabulary)\n\n      return predictions\n\n  def metrics(self, regularization_losses=None):\n    \"\"\"Creates metrics. 
See `base_head.Head` for details.\"\"\"\n    keys = metric_keys.MetricKeys\n    with ops.name_scope(None, 'metrics', (regularization_losses,)):\n      # Mean metric.\n      eval_metrics = {}\n      eval_metrics[self._loss_mean_key] = tf_keras.metrics.Mean(\n          name=keys.LOSS_MEAN)\n      # The default summation_method is \"interpolation\" in the AUC metric.\n      eval_metrics[self._auc_key] = tf_keras.metrics.AUC(name=keys.AUC)\n      eval_metrics[self._auc_pr_key] = tf_keras.metrics.AUC(\n          curve='PR', name=keys.AUC_PR)\n      if regularization_losses is not None:\n        eval_metrics[self._loss_regularization_key] = tf_keras.metrics.Mean(\n            name=keys.LOSS_REGULARIZATION)\n      for i, threshold in enumerate(self._thresholds):\n        eval_metrics[self._accuracy_keys[i]] = tf_keras.metrics.BinaryAccuracy(\n            name=self._accuracy_keys[i], threshold=threshold)\n        eval_metrics[self._precision_keys[i]] = (\n            tf_keras.metrics.Precision(\n                name=self._precision_keys[i], thresholds=threshold))\n        eval_metrics[self._recall_keys[i]] = tf_keras.metrics.Recall(\n            name=self._recall_keys[i], thresholds=threshold)\n      for i in range(len(self._classes_for_class_based_metrics)):\n        eval_metrics[self._prob_keys[i]] = tf_keras.metrics.Mean(\n            name=self._prob_keys[i])\n        eval_metrics[self._auc_keys[i]] = tf_keras.metrics.AUC(\n            name=self._auc_keys[i])\n        eval_metrics[self._auc_pr_keys[i]] = tf_keras.metrics.AUC(\n            curve='PR', name=self._auc_pr_keys[i])\n    return eval_metrics\n\n  def update_metrics(self,\n                     eval_metrics,\n                     features,\n                     logits,\n                     labels,\n                     regularization_losses=None):\n    \"\"\"Updates eval metrics. 
See `base_head.Head` for details.\"\"\"\n    logits = base_head.check_logits_final_dim(logits, self.logits_dimension)\n    processed_labels = self._processed_labels(logits, labels)\n    unweighted_loss, weights = self._unweighted_loss_and_weights(\n        logits, processed_labels, features)\n    prob_key = prediction_keys.PredictionKeys.PROBABILITIES\n    predictions = self.predictions(logits, [prob_key])\n    probabilities = predictions[prob_key]\n\n    # Update metrics.\n    eval_metrics[self._loss_mean_key].update_state(\n        values=unweighted_loss, sample_weight=weights)\n    eval_metrics[self._auc_key].update_state(\n        y_true=processed_labels, y_pred=probabilities, sample_weight=weights)\n    eval_metrics[self._auc_pr_key].update_state(\n        y_true=processed_labels, y_pred=probabilities, sample_weight=weights)\n    if regularization_losses is not None:\n      regularization_loss = tf.math.add_n(regularization_losses)\n      eval_metrics[self._loss_regularization_key].update_state(\n          values=regularization_loss)\n    for i in range(len(self._thresholds)):\n      eval_metrics[self._accuracy_keys[i]].update_state(\n          y_true=processed_labels, y_pred=probabilities, sample_weight=weights)\n      eval_metrics[self._precision_keys[i]].update_state(\n          y_true=processed_labels, y_pred=probabilities, sample_weight=weights)\n      eval_metrics[self._recall_keys[i]].update_state(\n          y_true=processed_labels, y_pred=probabilities, sample_weight=weights)\n    for i, class_id in enumerate(self._classes_for_class_based_metrics):\n      batch_rank = tf.rank(probabilities) - 1\n      begin = tf.concat(\n          [tf.zeros([batch_rank], dtype=tf.dtypes.int32), [class_id]], axis=0)\n      size = tf.concat([-1 * tf.ones([batch_rank], dtype=tf.dtypes.int32), [1]],\n                       axis=0)\n      class_probabilities = tf.slice(probabilities, begin=begin, size=size)\n      class_labels = tf.slice(processed_labels, begin=begin, 
size=size)\n      base_head.update_metric_with_broadcast_weights(\n          eval_metrics[self._prob_keys[i]], class_probabilities, weights)\n      eval_metrics[self._auc_keys[i]].update_state(\n          y_true=class_labels,\n          y_pred=class_probabilities,\n          sample_weight=weights)\n      eval_metrics[self._auc_pr_keys[i]].update_state(\n          y_true=class_labels,\n          y_pred=class_probabilities,\n          sample_weight=weights)\n    return eval_metrics\n\n  def _create_tpu_estimator_spec(self,\n                                 features,\n                                 mode,\n                                 logits,\n                                 labels=None,\n                                 optimizer=None,\n                                 trainable_variables=None,\n                                 train_op_fn=None,\n                                 update_ops=None,\n                                 regularization_losses=None):\n    \"\"\"Returns an `model_fn._TPUEstimatorSpec`.\n\n    Args:\n      features: Input `dict` of `Tensor` or `SparseTensor` objects.\n      mode: Estimator's `ModeKeys`.\n      logits: logits `Tensor` with shape `[D0, D1, ... DN, n_classes]`. For many\n        applications, the shape is `[batch_size, n_classes]`.\n      labels: Labels with shape matching `logits`. Can be multi-hot `Tensor`\n        with shape `[D0, D1, ... DN, n_classes]` or `SparseTensor` with\n        `dense_shape` `[D0, D1, ... DN, ?]`. `labels` is required argument when\n        `mode` equals `TRAIN` or `EVAL`.\n      optimizer: An `tf_keras.optimizers.Optimizer` instance to optimize the\n        loss in TRAIN mode. Namely, sets `train_op = optimizer.get_updates(loss,\n        trainable_variables)`, which updates variables to minimize\n        `loss`.able_variables)`, which updates variables to minimize `loss`.\n      trainable_variables: A list or tuple of `Variable` objects to update to\n        minimize `loss`. 
In Tensorflow 1.x, by default these are the list of\n        variables collected in the graph under the key\n        `GraphKeys.TRAINABLE_VARIABLES`. As Tensorflow 2.x doesn't have\n        collections and GraphKeys, trainable_variables need to be passed\n        explicitly here.\n      train_op_fn: Function that takes a scalar loss `Tensor` and returns\n        `train_op`. Used if `optimizer` is `None`.\n      update_ops: A list or tuple of update ops to be run at training time. For\n        example, layers such as BatchNormalization create mean and variance\n        update ops that need to be run at training time. In Tensorflow 1.x,\n        these are thrown into an UPDATE_OPS collection. As Tensorflow 2.x\n        doesn't have collections, update_ops need to be passed explicitly here.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses. These losses are\n        usually expressed as a batch average, so for best results users need to\n        set `loss_reduction=SUM_OVER_BATCH_SIZE` when creating the head to avoid\n        scaling errors.\n\n    Returns:\n      `model_fn._TPUEstimatorSpec`.\n    Raises:\n      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN\n        mode, or if both are set.\n    \"\"\"\n    with ops.name_scope(self._name, 'head'):\n      # Predict.\n      pred_keys = prediction_keys.PredictionKeys\n      predictions = self.predictions(logits)\n      if mode == ModeKeys.PREDICT:\n        probabilities = predictions[pred_keys.PROBABILITIES]\n        classifier_output = base_head.classification_output(\n            scores=probabilities,\n            n_classes=self._n_classes,\n            label_vocabulary=self._label_vocabulary)\n        return model_fn._TPUEstimatorSpec(  # pylint:disable=protected-access\n            mode=ModeKeys.PREDICT,\n            predictions=predictions,\n            export_outputs={\n                
base_head.DEFAULT_SERVING_KEY: classifier_output,\n                base_head.CLASSIFY_SERVING_KEY: classifier_output,\n                base_head.PREDICT_SERVING_KEY: (\n                    export_output.PredictOutput(predictions))\n            })\n\n      regularized_training_loss = self.loss(\n          logits=logits,\n          labels=labels,\n          features=features,\n          mode=mode,\n          regularization_losses=regularization_losses)\n      # Eval.\n      if mode == ModeKeys.EVAL:\n        eval_metrics = self.metrics(regularization_losses=regularization_losses)\n        return model_fn._TPUEstimatorSpec(  # pylint:disable=protected-access\n            mode=ModeKeys.EVAL,\n            predictions=predictions,\n            loss=regularized_training_loss,\n            eval_metrics=base_head.create_eval_metrics_tuple(\n                self.update_metrics, {\n                    'eval_metrics': eval_metrics,\n                    'features': features,\n                    'logits': logits,\n                    'labels': labels,\n                    'regularization_losses': regularization_losses\n                }))\n      # Train.\n      train_op = base_head.create_estimator_spec_train_op(\n          head_name=self._name,\n          optimizer=optimizer,\n          train_op_fn=train_op_fn,\n          update_ops=update_ops,\n          trainable_variables=trainable_variables,\n          regularized_training_loss=regularized_training_loss,\n          loss_reduction=self._loss_reduction)\n    # Create summary.\n    base_head.create_estimator_spec_summary(\n        regularized_training_loss=regularized_training_loss,\n        regularization_losses=regularization_losses,\n        summary_key_fn=self._summary_key)\n    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n        mode=ModeKeys.TRAIN,\n        predictions=predictions,\n        loss=regularized_training_loss,\n        train_op=train_op)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/multi_label_head_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for multi_label_head.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import dnn\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.head import head_utils as test_lib\nfrom tensorflow_estimator.python.estimator.head import multi_label_head as head_lib\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n\ndef _sigmoid_cross_entropy(labels, logits):\n  \"\"\"Returns sigmoid cross entropy averaged over classes.\"\"\"\n  sigmoid_logits = 1 / (1 + np.exp(-logits))\n  unreduced_result = (-labels * np.log(sigmoid_logits) -\n                      (1 - labels) * np.log(1 - sigmoid_logits))\n  # Mean over classes\n  return np.mean(unreduced_result, axis=-1, keepdims=True)\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass MultiLabelHead(tf.test.TestCase):\n\n  def test_n_classes_is_none(self):\n    
with self.assertRaisesRegexp(\n        ValueError,\n        r'n_classes must be > 1 for multi-label classification\\. Given: None'):\n      head_lib.MultiLabelHead(n_classes=None)\n\n  def test_n_classes_is_1(self):\n    with self.assertRaisesRegexp(\n        ValueError,\n        r'n_classes must be > 1 for multi-label classification\\. Given: 1'):\n      head_lib.MultiLabelHead(n_classes=1)\n\n  def test_threshold_too_small(self):\n    with self.assertRaisesRegexp(\n        ValueError, r'thresholds must be in \\(0, 1\\) range\\. Given: 0\\.0'):\n      head_lib.MultiLabelHead(n_classes=2, thresholds=[0., 0.5])\n\n  def test_threshold_too_large(self):\n    with self.assertRaisesRegexp(\n        ValueError, r'thresholds must be in \\(0, 1\\) range\\. Given: 1\\.0'):\n      head_lib.MultiLabelHead(n_classes=2, thresholds=[0.5, 1.0])\n\n  def test_label_vocabulary_dict(self):\n    with self.assertRaisesRegexp(\n        ValueError, r'label_vocabulary must be a list or tuple\\. '\n        r'Given type: <(type|class) \\'dict\\'>'):\n      head_lib.MultiLabelHead(n_classes=2, label_vocabulary={'foo': 'bar'})\n\n  def test_label_vocabulary_wrong_size(self):\n    with self.assertRaisesRegexp(\n        ValueError,\n        r'Length of label_vocabulary must be n_classes \\(3\\). 
Given: 2'):\n      head_lib.MultiLabelHead(n_classes=3, label_vocabulary=['foo', 'bar'])\n\n  def test_invalid_loss_reduction(self):\n    with self.assertRaisesRegexp(\n        ValueError, r'Invalid loss_reduction: invalid_loss_reduction'):\n      head_lib.MultiLabelHead(\n          n_classes=3, loss_reduction='invalid_loss_reduction')\n    with self.assertRaisesRegexp(ValueError, r'Invalid loss_reduction: none'):\n      head_lib.MultiLabelHead(\n          n_classes=3, loss_reduction=tf.losses.Reduction.NONE)\n\n  def test_loss_fn_arg_labels_missing(self):\n\n    def _loss_fn(logits):\n      del logits  # Unused\n\n    with self.assertRaisesRegexp(\n        ValueError, r'loss_fn must contain argument: labels\\. '\n        r'Given arguments: \\(\\'logits\\',\\)'):\n      head_lib.MultiLabelHead(n_classes=3, loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_logits_missing(self):\n\n    def _loss_fn(labels):\n      del labels  # unused\n\n    with self.assertRaisesRegexp(\n        ValueError, r'loss_fn must contain argument: logits\\. '\n        r'Given arguments: \\(\\'labels\\',\\)'):\n      head_lib.MultiLabelHead(n_classes=3, loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_features_ok(self):\n\n    def _loss_fn(labels, logits, features):\n      del labels, logits, features  # Unused\n\n    head_lib.MultiLabelHead(n_classes=3, loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_invalid(self):\n\n    def _loss_fn(labels, logits, name=None):\n      del labels, logits, name  # Unused\n\n    with self.assertRaisesRegexp(ValueError,\n                                 r'loss_fn has unexpected args: \\[\\'name\\'\\]'):\n      head_lib.MultiLabelHead(n_classes=3, loss_fn=_loss_fn)\n\n  def test_classes_for_class_based_metrics_invalid(self):\n    with self.assertRaisesRegexp(\n        ValueError,\n        r'All classes_for_class_based_metrics must be in range \\[0, 2\\]\\. 
'\n        r'Given: -1'):\n      head_lib.MultiLabelHead(\n          n_classes=3, classes_for_class_based_metrics=[2, -1])\n\n  def test_classes_for_class_based_metrics_string_invalid(self):\n    with self.assertRaisesRegexp(ValueError, r'\\'z\\' is not in list'):\n      head_lib.MultiLabelHead(\n          n_classes=3,\n          label_vocabulary=['a', 'b', 'c'],\n          classes_for_class_based_metrics=['c', 'z'])\n\n  def test_predict(self):\n    n_classes = 4\n    head = head_lib.MultiLabelHead(n_classes)\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    logits = np.array([[0., 1., 2., -1.], [-1., -2., -3., 1.]],\n                      dtype=np.float32)\n    expected_probabilities = tf.math.sigmoid(logits)\n    expected_export_classes = [[b'0', b'1', b'2', b'3']] * 2\n\n    keys = prediction_keys.PredictionKeys\n    preds = head.predictions(logits,\n                             [keys.LOGITS, keys.PROBABILITIES, keys.CLASSES])\n    self.assertAllClose(logits, self.evaluate(preds[keys.LOGITS]))\n    self.assertAllClose(expected_probabilities,\n                        self.evaluate(preds[keys.PROBABILITIES]))\n    if tf.executing_eagerly():\n      return\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    self.assertItemsEqual(\n        (test_lib._DEFAULT_SERVING_KEY, 'predict', 'classification'),\n        spec.export_outputs.keys())\n    # Assert predictions and export_outputs.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      predictions = sess.run(spec.predictions)\n      self.assertAllEqual(expected_export_classes,\n                          predictions[prediction_keys.PredictionKeys.CLASSES])\n      self.assertAllClose(logits,\n                         
 predictions[prediction_keys.PredictionKeys.LOGITS])\n      self.assertAllClose(\n          expected_probabilities,\n          predictions[prediction_keys.PredictionKeys.PROBABILITIES])\n      self.assertAllClose(\n          expected_probabilities,\n          sess.run(spec.export_outputs[test_lib._DEFAULT_SERVING_KEY].scores))\n      self.assertAllEqual(\n          expected_export_classes,\n          sess.run(spec.export_outputs[test_lib._DEFAULT_SERVING_KEY].classes))\n\n  def test_weight_should_not_impact_prediction(self):\n    n_classes = 4\n    head = head_lib.MultiLabelHead(n_classes, weight_column='example_weights')\n    self.assertEqual(n_classes, head.logits_dimension)\n\n    logits = np.array([[0., 1., 2., -1.], [-1., -2., -3., 1.]],\n                      dtype=np.float32)\n    expected_probabilities = tf.math.sigmoid(logits)\n    expected_export_classes = [[b'0', b'1', b'2', b'3']] * 2\n    weights_2x1 = [[1.], [2.]]\n    features = {\n        'x': np.array(((42,),), dtype=np.int32),\n        'example_weights': weights_2x1\n    }\n\n    keys = prediction_keys.PredictionKeys\n    preds = head.predictions(logits,\n                             [keys.LOGITS, keys.PROBABILITIES, keys.CLASSES])\n    self.assertAllClose(logits, self.evaluate(preds[keys.LOGITS]))\n    self.assertAllClose(expected_probabilities,\n                        self.evaluate(preds[keys.PROBABILITIES]))\n    if tf.executing_eagerly():\n      return\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.PREDICT,\n        logits=logits,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    # Assert predictions and export_outputs.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      predictions = sess.run(spec.predictions)\n      self.assertAllEqual(expected_export_classes,\n                          
predictions[prediction_keys.PredictionKeys.CLASSES])\n      self.assertAllClose(logits,\n                          predictions[prediction_keys.PredictionKeys.LOGITS])\n      self.assertAllClose(\n          expected_probabilities,\n          predictions[prediction_keys.PredictionKeys.PROBABILITIES])\n\n  def test_eval_create_loss(self):\n    \"\"\"Tests head.loss for eval mode.\"\"\"\n    n_classes = 2\n    head = head_lib.MultiLabelHead(n_classes)\n\n    logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)\n    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # loss = (labels * -log(sigmoid(logits)) +\n    #         (1 - labels) * -log(1 - sigmoid(logits))) / 2\n    expected_training_loss = 0.5 * np.sum(\n        _sigmoid_cross_entropy(labels=labels, logits=logits))\n    actual_training_loss = head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)\n    self.assertAllClose(expected_training_loss,\n                        self.evaluate(actual_training_loss))\n\n  def test_eval_create_loss_large_logits(self):\n    \"\"\"Tests head.loss for eval mode and large logits.\"\"\"\n    n_classes = 2\n    head = head_lib.MultiLabelHead(n_classes)\n\n    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)\n    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # loss = labels * -log(sigmoid(logits)) +\n    #        (1 - labels) * -log(1 - sigmoid(logits))\n    # For large logits, this is approximated as:\n    # loss = labels * (logits < 0) * (-logits) +\n    #        (1 - labels) * (logits > 0) * logits\n    expected_training_loss = 0.5 * np.sum(\n        np.array([[(10. + 10.) / 2.], [(15. + 0.) 
/ 2.]], dtype=np.float32))\n    actual_training_loss = head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)\n    self.assertAllClose(\n        expected_training_loss, self.evaluate(actual_training_loss), atol=1e-4)\n\n  def test_eval_create_loss_labels_wrong_shape(self):\n    \"\"\"Tests head.loss for eval mode when labels has the wrong shape.\"\"\"\n    n_classes = 2\n    head = head_lib.MultiLabelHead(n_classes)\n\n    logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)\n    labels_2x1 = np.array([[1], [1]], dtype=np.int64)\n    labels_2 = np.array([1, 1], dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, 'Expected labels dimension=2'):\n        head.loss(\n            logits=logits,\n            labels=labels_2x1,\n            features=features,\n            mode=ModeKeys.EVAL)\n      with self.assertRaisesRegexp(ValueError, 'Expected labels dimension=2'):\n        head.loss(\n            logits=logits,\n            labels=labels_2,\n            features=features,\n            mode=ModeKeys.EVAL)\n    else:\n      labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.int64)\n      actual_training_loss = head.loss(\n          logits=logits,\n          labels=labels_placeholder,\n          features=features,\n          mode=ModeKeys.EVAL)\n      with self.cached_session():\n        test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())\n        with self.assertRaisesRegexp(\n            tf.errors.InvalidArgumentError,\n            r'\\[expected_labels_shape: \\] \\[2 2\\] \\[labels_shape: \\] \\[2 1\\]'):\n          actual_training_loss.eval({labels_placeholder: labels_2x1})\n        with self.assertRaisesRegexp(\n            tf.errors.InvalidArgumentError,\n            r'labels shape must be \\[D0, D1, ... 
DN, 2\\]\\..*'\n            r'\\[Received shape: \\] \\[2\\]'):\n          actual_training_loss.eval({labels_placeholder: labels_2})\n\n  def test_eval_create_loss_loss_fn(self):\n    \"\"\"Tests head.loss for eval mode and custom loss_fn.\"\"\"\n    loss = np.array([[1.], [2.]], dtype=np.float32)\n    logits_input = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)\n    labels_input = np.array([[1, 0], [1, 1]], dtype=np.int64)\n\n    def _loss_fn(labels, logits):\n      check_labels = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(labels, labels_input)), data=[labels])\n      check_logits = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(logits, logits_input)), data=[logits])\n      with tf.control_dependencies([check_labels, check_logits]):\n        return tf.constant(loss)\n\n    head = head_lib.MultiLabelHead(n_classes=2, loss_fn=_loss_fn)\n\n    actual_training_loss = head.loss(\n        logits=logits_input,\n        labels=labels_input,\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.EVAL)\n    self.assertAllClose(np.sum(loss) / 2., self.evaluate(actual_training_loss))\n\n  def test_eval_create_loss_loss_fn_wrong_shape(self):\n    \"\"\"Tests custom loss_fn that returns Tensor of unexpected shape.\"\"\"\n    loss = np.array([1., 2.], dtype=np.float32)\n\n    def _loss_fn(labels, logits):\n      del labels, logits  # Unused\n      return tf.constant(loss)\n\n    head = head_lib.MultiLabelHead(n_classes=2, loss_fn=_loss_fn)\n\n    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)\n    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, 'loss_shape'):\n        head.loss(\n            logits=logits, labels=labels, features=features, mode=ModeKeys.EVAL)\n    else:\n      actual_training_loss = head.loss(\n          logits=logits, labels=labels, 
features=features, mode=ModeKeys.EVAL)\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[loss_fn must return Tensor of shape \\[D0, D1, ... DN, 1\\]\\. \\] '\n          r'\\[logits_shape: \\] \\[2 2\\] \\[loss_shape: \\] \\[2\\]'):\n        self.evaluate(actual_training_loss)\n\n  def test_eval_labels_none(self):\n    \"\"\"Tests that error is raised when labels is None.\"\"\"\n    head = head_lib.MultiLabelHead(n_classes=2)\n    with self.assertRaisesRegexp(\n        ValueError, r'You must provide a labels Tensor\\. Given: None\\.'):\n      head.loss(\n          logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),\n          labels=None,\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.EVAL)\n\n  def _test_eval(self,\n                 head,\n                 logits,\n                 labels,\n                 expected_loss,\n                 expected_metrics,\n                 features=None,\n                 regularization_losses=None):\n    tol = 1e-3\n    if tf.executing_eagerly():\n      loss = head.loss(\n          labels,\n          logits,\n          features=features or {},\n          mode=ModeKeys.EVAL,\n          regularization_losses=regularization_losses)\n      self.assertIsNotNone(loss)\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n\n      eval_metrics = head.metrics(regularization_losses=regularization_losses)\n      updated_metrics = head.update_metrics(\n          eval_metrics,\n          features or {},\n          logits,\n          labels,\n          regularization_losses=regularization_losses)\n      self.assertItemsEqual(expected_metrics.keys(), updated_metrics.keys())\n      self.assertAllClose(\n          expected_metrics,\n          {k: updated_metrics[k].result() for k in updated_metrics},\n          rtol=tol,\n          atol=tol)\n      return\n\n    spec = head.create_estimator_spec(\n        features=features or {},\n    
    mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels,\n        regularization_losses=regularization_losses,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    # Assert spec contains expected tensors.\n    self.assertIsNotNone(spec.loss)\n    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())\n    self.assertIsNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    test_lib._assert_no_hooks(self, spec)\n    # Assert predictions, loss, and metrics.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      loss, _ = sess.run((spec.loss, update_ops))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      # Check results of value ops (in `metrics`).\n      self.assertAllClose(\n          expected_metrics, {k: value_ops[k].eval() for k in value_ops},\n          rtol=tol,\n          atol=tol)\n\n  def test_eval(self):\n    n_classes = 2\n    head = head_lib.MultiLabelHead(n_classes)\n    logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)\n    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    # loss = labels * -log(sigmoid(logits)) +\n    #        (1 - labels) * -log(1 - sigmoid(logits))\n    # Sum over examples, divide by batch_size.\n    expected_loss = 0.5 * np.sum(\n        _sigmoid_cross_entropy(labels=labels, logits=logits))\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        # Average loss over examples.\n        keys.LOSS_MEAN: expected_loss,\n        # auc and auc_pr cannot be reliably calculated for only 4 samples, but\n        # this assert tests that the algorithm remains consistent.\n        keys.AUC: 0.3333,\n        keys.AUC_PR: 0.7689,\n    }\n    
self._test_eval(\n        head=head,\n        logits=logits,\n        labels=labels,\n        expected_loss=expected_loss,\n        expected_metrics=expected_metrics)\n\n  def test_eval_sparse_labels(self):\n    n_classes = 2\n    head = head_lib.MultiLabelHead(n_classes)\n    logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)\n    # Equivalent to multi_hot = [[1, 0], [1, 1]]\n    labels = tf.sparse.SparseTensor(\n        values=[0, 0, 1], indices=[[0, 0], [1, 0], [1, 1]], dense_shape=[2, 2])\n    labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    # loss = labels * -log(sigmoid(logits)) +\n    #        (1 - labels) * -log(1 - sigmoid(logits))\n    # Sum over examples, divide by batch_size.\n    expected_loss = 0.5 * np.sum(\n        _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        # Average loss over examples.\n        keys.LOSS_MEAN: expected_loss,\n        # auc and auc_pr cannot be reliably calculated for only 4 samples, but\n        # this assert tests that the algorithm remains consistent.\n        keys.AUC: 0.3333,\n        keys.AUC_PR: 0.7689,\n    }\n    self._test_eval(\n        head=head,\n        logits=logits,\n        labels=labels,\n        expected_loss=expected_loss,\n        expected_metrics=expected_metrics)\n\n  def test_eval_with_regularization_losses(self):\n    n_classes = 2\n    head = head_lib.MultiLabelHead(n_classes)\n    logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)\n    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    regularization_losses = [1.5, 0.5]\n    expected_regularization_loss = 2.\n    # unregularized_loss = sum(\n    #     labels * -log(sigmoid(logits)) +\n    #     (1 - labels) * -log(1 - sigmoid(logits))) / batch_size\n    expected_unregularized_loss = np.sum(\n        _sigmoid_cross_entropy(labels=labels, logits=logits)) / 2.\n    expected_regularized_loss = (\n        
expected_unregularized_loss + expected_regularization_loss)\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: expected_unregularized_loss,\n        keys.LOSS_REGULARIZATION: expected_regularization_loss,\n        # auc and auc_pr cannot be reliably calculated for only 4 samples, but\n        # this assert tests that the algorithm remains consistent.\n        keys.AUC: 0.3333,\n        keys.AUC_PR: 0.7689,\n    }\n    self._test_eval(\n        head=head,\n        logits=logits,\n        labels=labels,\n        expected_loss=expected_regularized_loss,\n        expected_metrics=expected_metrics,\n        regularization_losses=regularization_losses)\n\n  def test_eval_with_label_vocabulary(self):\n    n_classes = 2\n    head = head_lib.MultiLabelHead(\n        n_classes, label_vocabulary=['class0', 'class1'])\n    logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)\n    # Equivalent to multi_hot = [[1, 0], [1, 1]]\n    labels = tf.sparse.SparseTensor(\n        values=['class0', 'class0', 'class1'],\n        indices=[[0, 0], [1, 0], [1, 1]],\n        dense_shape=[2, 2])\n    labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    # loss = labels * -log(sigmoid(logits)) +\n    #        (1 - labels) * -log(1 - sigmoid(logits))\n    # Sum over examples, divide by batch_size.\n    expected_loss = 0.5 * np.sum(\n        _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        # Average loss over examples.\n        keys.LOSS_MEAN: expected_loss,\n        # auc and auc_pr cannot be reliably calculated for only 4 samples, but\n        # this assert tests that the algorithm remains consistent.\n        keys.AUC: 0.3333,\n        keys.AUC_PR: 0.7689,\n    }\n    self._test_eval(\n        head=head,\n        logits=logits,\n        labels=labels,\n        expected_loss=expected_loss,\n        expected_metrics=expected_metrics)\n\n  def 
test_eval_with_label_vocabulary_with_multi_hot_input(self):\n    n_classes = 2\n    head = head_lib.MultiLabelHead(\n        n_classes, label_vocabulary=['class0', 'class1'])\n    logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)\n    labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    # loss = labels * -log(sigmoid(logits)) +\n    #        (1 - labels) * -log(1 - sigmoid(logits))\n    # Sum over examples, divide by batch_size.\n    expected_loss = 0.5 * np.sum(\n        _sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        # Average loss over examples.\n        keys.LOSS_MEAN: expected_loss,\n        # auc and auc_pr cannot be reliably calculated for only 4 samples, but\n        # this assert tests that the algorithm remains consistent.\n        keys.AUC: 0.3333,\n        keys.AUC_PR: 0.7689,\n    }\n    self._test_eval(\n        head=head,\n        logits=logits,\n        labels=labels_multi_hot,\n        expected_loss=expected_loss,\n        expected_metrics=expected_metrics)\n\n  def test_eval_with_thresholds(self):\n    n_classes = 2\n    thresholds = [0.25, 0.5, 0.75]\n    head = head_lib.MultiLabelHead(n_classes, thresholds=thresholds)\n\n    logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)\n    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    # loss = labels * -log(sigmoid(logits)) +\n    #        (1 - labels) * -log(1 - sigmoid(logits))\n    # Sum over examples, divide by batch_size.\n    expected_loss = 0.5 * np.sum(\n        _sigmoid_cross_entropy(labels=labels, logits=logits))\n\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        # Average loss over examples.\n        keys.LOSS_MEAN: expected_loss,\n        # auc and auc_pr cannot be reliably calculated for only 4 samples, but\n        # this assert tests that the algorithm remains consistent.\n        keys.AUC: 0.3333,\n        keys.AUC_PR: 0.7689,\n        
keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 2. / 4.,\n        keys.PRECISION_AT_THRESHOLD % thresholds[0]: 2. / 3.,\n        keys.RECALL_AT_THRESHOLD % thresholds[0]: 2. / 3.,\n        keys.ACCURACY_AT_THRESHOLD % thresholds[1]: 1. / 4.,\n        keys.PRECISION_AT_THRESHOLD % thresholds[1]: 1. / 2.,\n        keys.RECALL_AT_THRESHOLD % thresholds[1]: 1. / 3.,\n        keys.ACCURACY_AT_THRESHOLD % thresholds[2]: 2. / 4.,\n        keys.PRECISION_AT_THRESHOLD % thresholds[2]: 1. / 1.,\n        keys.RECALL_AT_THRESHOLD % thresholds[2]: 1. / 3.,\n    }\n\n    self._test_eval(\n        head=head,\n        logits=logits,\n        labels=labels,\n        expected_loss=expected_loss,\n        expected_metrics=expected_metrics)\n\n  def test_eval_with_classes_for_class_based_metrics(self):\n    head = head_lib.MultiLabelHead(\n        n_classes=2, classes_for_class_based_metrics=[0, 1])\n\n    logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)\n    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    # loss = labels * -log(sigmoid(logits)) +\n    #        (1 - labels) * -log(1 - sigmoid(logits))\n    # Sum over examples, divide by batch_size.\n    expected_loss = 0.5 * np.sum(\n        _sigmoid_cross_entropy(labels=labels, logits=logits))\n\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        # Average loss over examples.\n        keys.LOSS_MEAN: expected_loss,\n        # auc and auc_pr cannot be reliably calculated for only 4 samples, but\n        # this assert tests that the algorithm remains consistent.\n        keys.AUC: 0.3333,\n        keys.AUC_PR: 0.7689,\n        keys.PROBABILITY_MEAN_AT_CLASS % 0:\n            tf.math.reduce_sum(tf.math.sigmoid(logits[:, 0])) / 2.,\n        keys.AUC_AT_CLASS % 0: 0.,\n        keys.AUC_PR_AT_CLASS % 0: 1.,\n        keys.PROBABILITY_MEAN_AT_CLASS % 1:\n            tf.math.reduce_sum(tf.math.sigmoid(logits[:, 1])) / 2.,\n        keys.AUC_AT_CLASS % 1: 1.,\n        keys.AUC_PR_AT_CLASS % 1: 1.,\n    }\n\n  
  self._test_eval(\n        head=head,\n        logits=logits,\n        labels=labels,\n        expected_loss=expected_loss,\n        expected_metrics=expected_metrics)\n\n  def test_eval_with_classes_for_class_based_metrics_string(self):\n    head = head_lib.MultiLabelHead(\n        n_classes=2,\n        label_vocabulary=['a', 'b'],\n        classes_for_class_based_metrics=['a', 'b'])\n\n    logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)\n    labels = tf.sparse.SparseTensor(\n        values=['a', 'a', 'b'],\n        indices=[[0, 0], [1, 0], [1, 1]],\n        dense_shape=[2, 2])\n    labels_onehot = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    # loss = labels * -log(sigmoid(logits)) +\n    #        (1 - labels) * -log(1 - sigmoid(logits))\n    # Sum over examples, divide by batch_size.\n    expected_loss = 0.5 * np.sum(\n        _sigmoid_cross_entropy(labels=labels_onehot, logits=logits))\n\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        # Average loss over examples.\n        keys.LOSS_MEAN: expected_loss,\n        # auc and auc_pr cannot be reliably calculated for only 4 samples, but\n        # this assert tests that the algorithm remains consistent.\n        keys.AUC: 0.3333,\n        keys.AUC_PR: 0.7689,\n        keys.PROBABILITY_MEAN_AT_NAME % 'a':\n            tf.math.reduce_sum(tf.math.sigmoid(logits[:, 0])) / 2.,\n        keys.AUC_AT_NAME % 'a': 0.,\n        keys.AUC_PR_AT_NAME % 'a': 1.,\n        keys.PROBABILITY_MEAN_AT_NAME % 'b':\n            tf.math.reduce_sum(tf.math.sigmoid(logits[:, 1])) / 2.,\n        keys.AUC_AT_NAME % 'b': 1.,\n        keys.AUC_PR_AT_NAME % 'b': 1.,\n    }\n\n    self._test_eval(\n        head=head,\n        logits=logits,\n        labels=labels,\n        expected_loss=expected_loss,\n        expected_metrics=expected_metrics)\n\n  def test_eval_with_weights(self):\n    n_classes = 2\n    head = head_lib.MultiLabelHead(n_classes, weight_column='example_weights')\n\n    logits = 
np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)\n    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    features = {\n        'x': np.array([[41], [42]], dtype=np.int32),\n        'example_weights': np.array([[1.], [2.]], dtype=np.float32),\n    }\n    # For large logits, sigmoid cross entropy loss is approximated as:\n    # loss = labels * (logits < 0) * (-logits) +\n    #        (1 - labels) * (logits > 0) * logits =>\n    # expected_unweighted_loss = [[10., 10.], [15., 0.]]\n    # Average over classes, weighted sum over examples, divide by batch_size.\n    # loss = (1 * (10 + 10) / 2 + 2 * (15 + 0) / 2) / 2\n    expected_loss = 12.5\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        # Average loss over weighted examples (denominator is sum(weights)).\n        keys.LOSS_MEAN: expected_loss * (2. / 3.),\n        # auc and auc_pr cannot be reliably calculated for only 4 samples, but\n        # this assert tests that the algorithm remains consistent.\n        keys.AUC: 0.2000,\n        keys.AUC_PR: 0.7280,\n    }\n    self._test_eval(\n        head=head,\n        logits=logits,\n        labels=labels,\n        expected_loss=expected_loss,\n        expected_metrics=expected_metrics,\n        features=features)\n\n  def test_train_create_loss_large_logits(self):\n    \"\"\"Tests head.create_loss for train mode and large logits.\"\"\"\n    n_classes = 2\n    head = head_lib.MultiLabelHead(n_classes, weight_column='example_weights')\n\n    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)\n    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    weights = np.array([[1.], [2.]], dtype=np.float32)\n    features = {\n        'x': np.array(((42,),), dtype=np.int32),\n        'example_weights': weights\n    }\n    # loss = labels * -log(sigmoid(logits)) +\n    #        (1 - labels) * -log(1 - sigmoid(logits))\n    # For large logits, this is approximated as:\n    # loss = labels * (logits < 0) * (-logits) +\n    #        (1 - 
labels) * (logits > 0) * logits\n    # expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]]\n    # expected_weights = [[1.], [2.]]\n    expected_training_loss = (1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2.) / 2.\n    training_loss = head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n    self.assertAllClose(\n        expected_training_loss, self.evaluate(training_loss), atol=1e-4)\n\n  def test_train_create_loss_loss_reduction(self):\n    \"\"\"Tests head.create_loss with loss_reduction.\"\"\"\n    n_classes = 2\n    head = head_lib.MultiLabelHead(\n        n_classes,\n        weight_column='example_weights',\n        loss_reduction=tf.losses.Reduction.SUM)\n\n    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)\n    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    weights = np.array([[1.], [2.]], dtype=np.float32)\n    # loss = labels * -log(sigmoid(logits)) +\n    #        (1 - labels) * -log(1 - sigmoid(logits))\n    # For large logits, this is approximated as:\n    # loss = labels * (logits < 0) * (-logits) +\n    #        (1 - labels) * (logits > 0) * logits\n    # expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]]\n    # expected_weights = [[1.], [2.]]\n    expected_training_loss = (1. * (10. + 10.) + 2. * (15. + 0.)) / 2.\n    training_loss = head.loss(\n        logits=logits,\n        labels=labels,\n        features={\n            'x': np.array(((42,),), dtype=np.int32),\n            'example_weights': weights\n        },\n        mode=ModeKeys.TRAIN)\n    self.assertAllClose(\n        expected_training_loss, self.evaluate(training_loss), atol=1e-4)\n\n  def test_train_labels_none(self):\n    \"\"\"Tests that error is raised when labels is None.\"\"\"\n    head = head_lib.MultiLabelHead(n_classes=2)\n\n    with self.assertRaisesRegexp(\n        ValueError, r'You must provide a labels Tensor\\. 
Given: None\\.'):\n      head.loss(\n          logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),\n          labels=None,\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN)\n\n  def test_train_invalid_indicator_labels(self):\n    head = head_lib.MultiLabelHead(n_classes=2)\n    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)\n    # The value 2 is outside the allowed range.\n    labels = np.array([[2, 0], [1, 1]], dtype=np.int64)\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(\n          ValueError,\n          r'labels must be an integer indicator Tensor with values in '\n          r'\\[0, 1\\]'):\n        head.loss(\n            logits=logits, labels=labels, features={}, mode=ModeKeys.TRAIN)\n      return\n\n    def _train_op_fn(loss):\n      del loss\n      return tf.no_op()\n\n    with self.cached_session() as sess:\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'labels must be an integer indicator Tensor with values in '\n          r'\\[0, 1\\]'):\n        spec = head.create_estimator_spec(\n            features={},\n            mode=ModeKeys.TRAIN,\n            logits=logits,\n            labels=labels,\n            train_op_fn=_train_op_fn,\n            trainable_variables=[\n                tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n            ])\n        test_lib._initialize_variables(self, spec.scaffold)\n        sess.run(spec.loss)\n\n  def test_train_invalid_sparse_labels(self):\n    head = head_lib.MultiLabelHead(n_classes=2)\n    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)\n    # The value 2 is outside the allowed range.\n    labels = tf.sparse.SparseTensor(\n        values=[2, 0, 1], indices=[[0, 0], [1, 0], [1, 1]], dense_shape=[2, 2])\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(\n          ValueError,\n          r'labels must be an integer SparseTensor with values 
in \\[0, 2\\)'):\n        head.loss(\n            logits=logits, labels=labels, features={}, mode=ModeKeys.TRAIN)\n      return\n\n    def _train_op_fn(loss):\n      del loss\n      return tf.no_op()\n\n    with self.cached_session() as sess:\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'labels must be an integer SparseTensor with values in \\[0, 2\\)'):\n        spec = head.create_estimator_spec(\n            features={},\n            mode=ModeKeys.TRAIN,\n            logits=logits,\n            labels=labels,\n            train_op_fn=_train_op_fn,\n            trainable_variables=[\n                tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n            ])\n        test_lib._initialize_variables(self, spec.scaffold)\n        sess.run(spec.loss)\n\n  def _test_train(self, head, logits, labels, expected_loss):\n    tol = 1e-3\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    if tf.executing_eagerly():\n      loss = head.loss(\n          logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n      self.assertIsNotNone(loss)\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      return\n\n    expected_train_result = 'my_train_op'\n\n    def _train_op_fn(loss):\n      return tf.strings.join([\n          tf.constant(expected_train_result),\n          tf.strings.as_string(loss, precision=3)\n      ])\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    self.assertIsNotNone(spec.loss)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    test_lib._assert_no_hooks(self, spec)\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n    
  test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),\n          train_result)\n      test_lib._assert_simple_summaries(\n          self, {metric_keys.MetricKeys.LOSS: expected_loss}, summary_str, tol)\n\n  def test_train(self):\n    head = head_lib.MultiLabelHead(n_classes=2)\n    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)\n    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    # For large logits, sigmoid cross entropy loss is approximated as:\n    # loss = labels * (logits < 0) * (-logits) +\n    #        (1 - labels) * (logits > 0) * logits =>\n    # expected_unweighted_loss = [[10., 10.], [15., 0.]]\n    # Average over classes, sum over examples, divide by batch_size.\n    # loss = ((10 + 10) / 2 + (15 + 0) / 2 ) / 2\n    expected_loss = 8.75\n    self._test_train(\n        head=head, logits=logits, labels=labels, expected_loss=expected_loss)\n\n  def test_train_sparse_labels(self):\n    head = head_lib.MultiLabelHead(n_classes=2)\n    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)\n    # Equivalent to multi_hot = [[1, 0], [1, 1]]\n    labels = tf.sparse.SparseTensor(\n        values=[0, 0, 1], indices=[[0, 0], [1, 0], [1, 1]], dense_shape=[2, 2])\n    # For large logits, sigmoid cross entropy loss is approximated as:\n    # loss = labels * (logits < 0) * (-logits) +\n    #        (1 - labels) * (logits > 0) * logits =>\n    # expected_unweighted_loss = [[10., 10.], [15., 0.]]\n    # Average over classes, sum over examples, divide by batch_size.\n    # loss = ((10 + 10) / 2 + (15 + 0) / 2 ) / 2\n    expected_loss = 8.75\n    self._test_train(\n        head=head, logits=logits, 
labels=labels, expected_loss=expected_loss)\n\n  def test_train_with_label_vocabulary(self):\n    head = head_lib.MultiLabelHead(\n        n_classes=2, label_vocabulary=['class0', 'class1'])\n    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)\n    # Equivalent to multi_hot = [[1, 0], [1, 1]]\n    labels = tf.sparse.SparseTensor(\n        values=['class0', 'class0', 'class1'],\n        indices=[[0, 0], [1, 0], [1, 1]],\n        dense_shape=[2, 2])\n    # For large logits, sigmoid cross entropy loss is approximated as:\n    # loss = labels * (logits < 0) * (-logits) +\n    #        (1 - labels) * (logits > 0) * logits =>\n    # expected_unweighted_loss = [[10., 10.], [15., 0.]]\n    # Average over classes, sum over examples, divide by batch_size.\n    # loss = ((10 + 10) / 2 + (15 + 0) / 2 ) / 2\n    expected_loss = 8.75\n    self._test_train(\n        head=head, logits=logits, labels=labels, expected_loss=expected_loss)\n\n  def test_train_with_regularization_losses(self):\n    head = head_lib.MultiLabelHead(n_classes=2)\n    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)\n    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    regularization_losses = [1.5, 0.5]\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # For large logits, sigmoid cross entropy loss is approximated as:\n    # loss = labels * (logits < 0) * (-logits) +\n    #        (1 - labels) * (logits > 0) * logits =>\n    # expected_unweighted_loss = [[10., 10.], [15., 0.]]\n    # Average over classes and over batch and add regularization loss.\n    expected_loss = 35. / 4. 
+ 2.\n    expected_summaries = {\n        metric_keys.MetricKeys.LOSS: expected_loss,\n        metric_keys.MetricKeys.LOSS_REGULARIZATION: 2.,\n    }\n    tol = 1e-3\n    loss = head.loss(\n        logits=logits,\n        labels=labels,\n        features=features,\n        mode=ModeKeys.TRAIN,\n        regularization_losses=regularization_losses)\n    self.assertIsNotNone(loss)\n    self.assertAllClose(expected_loss, self.evaluate(loss), rtol=tol, atol=tol)\n    if tf.executing_eagerly():\n      return\n\n    expected_train_result = 'my_train_op'\n\n    def _train_op_fn(loss):\n      return tf.strings.join([\n          tf.constant(expected_train_result),\n          tf.strings.as_string(loss, precision=3)\n      ])\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        regularization_losses=regularization_losses,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),\n          train_result)\n      test_lib._assert_simple_summaries(self, expected_summaries, summary_str,\n                                        tol)\n\n  def test_train_with_weights(self):\n    n_classes = 2\n    head = head_lib.MultiLabelHead(n_classes, weight_column='example_weights')\n\n    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)\n    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    features = {\n        'x': 
np.array([[41], [42]], dtype=np.int32),\n        'example_weights': np.array([[1.], [2.]], dtype=np.float32),\n    }\n    # For large logits, sigmoid cross entropy loss is approximated as:\n    # loss = labels * (logits < 0) * (-logits) +\n    #        (1 - labels) * (logits > 0) * logits =>\n    # expected_unweighted_loss = [[10., 10.], [15., 0.]]\n    # Average over classes, weighted sum over examples, divide by batch_size.\n    # loss = (1 * (10 + 10) / 2 + 2 * (15 + 0) / 2) / 2\n    expected_loss = 12.5\n    tol = 1e-3\n\n    loss = head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n    self.assertIsNotNone(loss)\n    self.assertAllClose(expected_loss, self.evaluate(loss), rtol=tol, atol=tol)\n    if tf.executing_eagerly():\n      return\n\n    expected_train_result = 'my_train_op'\n\n    def _train_op_fn(loss):\n      return tf.strings.join([\n          tf.constant(expected_train_result),\n          tf.strings.as_string(loss, precision=3)\n      ])\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    self.assertIsNotNone(spec.loss)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    test_lib._assert_no_hooks(self, spec)\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      loss, train_result, summary_str = sess.run(\n          (spec.loss, spec.train_op, spec.scaffold.summary_op))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),\n          
train_result)\n      test_lib._assert_simple_summaries(self, {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n      }, summary_str, tol)\n\n  def test_multi_dim_weighted_train_create_loss(self):\n    \"\"\"Logits and labels of shape [2, 2, 3], weights [2, 2].\"\"\"\n    head = head_lib.MultiLabelHead(n_classes=3, weight_column='weights')\n\n    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],\n                       [[-12., 12., -12.], [12., -12., 12.]]],\n                      dtype=np.float32)\n    labels = np.array([[[1, 0, 0], [1, 0, 0]], [[0, 1, 1], [0, 1, 1]]],\n                      dtype=np.int64)\n    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)\n    # unreduced_loss =\n    #     [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3\n    #   = [[20/3, 10/3], [4, 8]]\n    # expected_unreduced_loss = [[[20./3.], [10./3.]], [[4.], [8.]]]\n    # weights are reshaped to [2, 2, 1] to match logits.\n    # expected_weights = [[[1.], [1.5]], [[2.], [2.5]]]\n    # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167\n    expected_training_loss = 9.9167\n    training_loss = head.loss(\n        logits=logits,\n        labels=labels,\n        features={'weights': weights},\n        mode=ModeKeys.TRAIN)\n    atol = 1.e-3\n    self.assertAllClose(\n        expected_training_loss, self.evaluate(training_loss), atol=atol)\n\n  def test_multi_dim_weighted_train(self):\n    \"\"\"Logits and labels of shape [2, 2, 3], weights [2, 2].\"\"\"\n    head = head_lib.MultiLabelHead(n_classes=3, weight_column='weights')\n\n    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],\n                       [[-12., 12., -12.], [12., -12., 12.]]],\n                      dtype=np.float32)\n    labels = np.array([[[1, 0, 0], [1, 0, 0]], [[0, 1, 1], [0, 1, 1]]],\n                      dtype=np.int64)\n    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)\n    # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3\n    #      
= [[20/3, 10/3], [4, 8]]\n    # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167\n    expected_loss = 9.9167\n    atol = 1.e-3\n\n    loss = head.loss(\n        logits=logits,\n        labels=labels,\n        features={'weights': weights},\n        mode=ModeKeys.TRAIN)\n    self.assertIsNotNone(loss)\n    self.assertAllClose(expected_loss, self.evaluate(loss), atol=atol)\n    if tf.executing_eagerly():\n      return\n\n    expected_train_result = 'my_train_op'\n\n    def _train_op_fn(loss):\n      return tf.strings.join([\n          tf.constant(expected_train_result),\n          tf.strings.as_string(loss, precision=3)\n      ])\n\n    spec = head.create_estimator_spec(\n        features={'weights': weights},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())\n      loss, train_result = sess.run((spec.loss, spec.train_op))\n      self.assertAllClose(expected_loss, loss, atol=atol)\n      self.assertEqual(\n          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),\n          train_result)\n\n  def test_multi_dim_weights_wrong_inner_dim(self):\n    \"\"\"Logits and labels of shape [2, 2, 3], weights [2, 1].\"\"\"\n    head = head_lib.MultiLabelHead(n_classes=3, weight_column='weights')\n\n    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],\n                       [[-12., 12., -12.], [12., -12., 12.]]],\n                      dtype=np.float32)\n    labels = np.array([[[1, 0, 0], [1, 0, 0]], [[0, 1, 1], [0, 1, 1]]],\n                      dtype=np.int64)\n    weights = np.array([[1.], [2.]], dtype=np.float32)\n\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, 'weights shape'):\n        head.loss(\n            logits=logits,\n            
labels=labels,\n            features={'weights': weights},\n            mode=ModeKeys.TRAIN)\n      return\n\n    def _train_op_fn(loss):\n      del loss\n      return tf.no_op()\n\n    spec = head.create_estimator_spec(\n        features={'weights': weights},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    with self.cached_session():\n      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[logits_shape: \\] \\[2 2 3\\] \\[weights_shape: \\] \\[2 1\\]'):\n        spec.loss.eval()\n\n  def test_multi_dim_weights_wrong_outer_dim(self):\n    \"\"\"Logits and labels of shape [2, 2, 3], weights [2, 2, 3].\"\"\"\n    head = head_lib.MultiLabelHead(n_classes=3, weight_column='weights')\n\n    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],\n                       [[-12., 12., -12.], [12., -12., 12.]]],\n                      dtype=np.float32)\n    labels = np.array([[[1, 0, 0], [1, 0, 0]], [[0, 1, 1], [0, 1, 1]]],\n                      dtype=np.int64)\n    weights = np.array(\n        [[[1., 1., 1.], [1.5, 1.5, 1.5]], [[2., 2., 2.], [2.5, 2.5, 2.5]]],\n        dtype=np.float32)\n\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, 'weights shape'):\n        head.loss(\n            logits=logits,\n            labels=labels,\n            features={'weights': weights},\n            mode=ModeKeys.TRAIN)\n      return\n\n    weights_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n\n    def _train_op_fn(loss):\n      del loss\n      return tf.no_op()\n\n    spec = head.create_estimator_spec(\n        features={'weights': weights_placeholder},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        
trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    with self.cached_session():\n      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[logits_shape: \\] \\[2 2 3\\] \\[weights_shape: \\] \\[2 2 3\\]'):\n        spec.loss.eval({weights_placeholder: weights})\n\n  def test_multi_dim_weighted_eval(self):\n    \"\"\"Logits and labels of shape [2, 2, 3], weights [2, 2].\"\"\"\n    head = head_lib.MultiLabelHead(n_classes=3, weight_column='weights')\n\n    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],\n                       [[-12., 12., -12.], [12., -12., 12.]]],\n                      dtype=np.float32)\n    labels = np.array([[[1, 0, 0], [1, 0, 0]], [[0, 1, 1], [0, 1, 1]]],\n                      dtype=np.int64)\n    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)\n    # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3\n    #      = [[20/3, 10/3], [4, 8]]\n    # loss = (1*20/3 + 1.5*10/3 + 2*4 + 2.5*8) / 4 = 9.9167\n    expected_loss = 9.9167\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: expected_loss * (4. 
/ np.sum(weights)),\n        # auc and auc_pr cannot be reliably calculated for only 4 samples, but\n        # this assert tests that the algorithm remains consistent.\n        keys.AUC: 0.4977,\n        keys.AUC_PR: 0.5461,\n    }\n    self._test_eval(\n        head=head,\n        features={'weights': weights},\n        logits=logits,\n        labels=labels,\n        expected_loss=expected_loss,\n        expected_metrics=expected_metrics)\n\n\n@test_util.deprecated_graph_mode_only\nclass MultiLabelHeadForEstimator(tf.test.TestCase):\n  \"\"\"Tests for create_estimator_spec running in Graph mode only.\"\"\"\n\n  def test_invalid_trainable_variables(self):\n    head = head_lib.MultiLabelHead(n_classes=2)\n\n    class _Optimizer(tf_keras.optimizers.Optimizer):\n\n      def get_updates(self, loss, params):\n        del params\n        return [\n            tf.strings.join([\n                tf.constant('my_train_op'),\n                tf.strings.as_string(loss, precision=2)\n            ])\n        ]\n\n      def get_config(self):\n        config = super(_Optimizer, self).get_config()\n        return config\n\n    with self.assertRaisesRegexp(ValueError,\n                                 r'trainable_variables cannot be None'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),\n          labels=np.array([[1, 0], [1, 1]], dtype=np.int64),\n          optimizer=_Optimizer('my_optimizer'),\n          trainable_variables=None)\n    with self.assertRaisesRegexp(\n        ValueError, r'trainable_variables should be a list or a tuple'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),\n          labels=np.array([[1, 0], [1, 1]], dtype=np.int64),\n          
optimizer=_Optimizer('my_optimizer'),\n          trainable_variables={\n              'var_list': [tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)]\n          })\n\n  def test_train_with_optimizer(self):\n    head = head_lib.MultiLabelHead(n_classes=2)\n    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)\n    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    # For large logits, sigmoid cross entropy loss is approximated as:\n    # loss = labels * (logits < 0) * (-logits) +\n    #        (1 - labels) * (logits > 0) * logits =>\n    # expected_unweighted_loss = [[10., 10.], [15., 0.]]\n    # Average over classes, sum over examples, divide by batch_size.\n    # loss = ((10 + 10) / 2 + (15 + 0) / 2 ) / 2\n    expected_loss = 8.75\n    expected_train_result = 'my_train_op'\n\n    class _Optimizer(tf_keras.optimizers.Optimizer):\n\n      def get_updates(self, loss, params):\n        del params\n        return [\n            tf.strings.join([\n                tf.constant(expected_train_result),\n                tf.strings.as_string(loss, precision=3)\n            ])\n        ]\n\n      def get_config(self):\n        config = super(_Optimizer, self).get_config()\n        return config\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        optimizer=_Optimizer('my_optimizer'),\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    tol = 1e-3\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      loss, train_result = sess.run((spec.loss, spec.train_op))\n      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)\n      self.assertEqual(\n          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),\n          train_result)\n\n  def test_predict_with_label_vocabulary(self):\n    n_classes 
= 4\n    head = head_lib.MultiLabelHead(\n        n_classes, label_vocabulary=['foo', 'bar', 'foobar', 'barfoo'])\n\n    logits = np.array([[0., 1., 2., -1.], [-1., -2., -3., 1.]],\n                      dtype=np.float32)\n    expected_export_classes = [[b'foo', b'bar', b'foobar', b'barfoo']] * 2\n\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      predictions = sess.run(spec.predictions)\n      self.assertAllEqual(expected_export_classes,\n                          predictions[prediction_keys.PredictionKeys.CLASSES])\n      self.assertAllClose(logits,\n                          predictions[prediction_keys.PredictionKeys.LOGITS])\n      self.assertAllEqual(\n          expected_export_classes,\n          sess.run(spec.export_outputs[test_lib._DEFAULT_SERVING_KEY].classes))\n\n  def test_train_with_update_ops(self):\n    with tf.Graph().as_default():\n      w = tf.Variable(1)\n      update_op = w.assign_add(1)\n\n      t = tf.Variable('')\n      expected_train_result = b'my_train_op'\n\n      def _train_op_fn(loss):\n        del loss\n        return t.assign(expected_train_result)\n\n      head = head_lib.MultiLabelHead(n_classes=2)\n      spec = head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),\n          labels=np.array([[1, 0], [1, 1]], dtype=np.int64),\n          train_op_fn=_train_op_fn,\n          update_ops=[update_op],\n          trainable_variables=[\n              tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n          ])\n\n      with self.cached_session() as sess:\n        test_lib._initialize_variables(self, 
spec.scaffold)\n        sess.run(spec.train_op)\n        w_value, t_value = sess.run([w, t])\n        self.assertEqual(2, w_value)\n        self.assertEqual(expected_train_result, t_value)\n\n  def test_lookup_tables_in_graph(self):\n    n_classes = 2\n    head = head_lib.MultiLabelHead(\n        n_classes=n_classes, label_vocabulary=['class0', 'class1'])\n\n    feature_columns = [tf.feature_column.numeric_column('x')]\n    # Create dnn estimator.\n    est = dnn.DNNEstimatorV2(\n        head=head, hidden_units=(2, 2), feature_columns=feature_columns)\n\n    def input_fn():\n      return ({\n          'x': np.array(((42,), (43,),), dtype=np.int32)\n      }, np.array([[1, 0], [1, 1]], dtype=np.int64))\n\n    # Train.\n    num_steps = 1\n    est.train(input_fn, steps=num_steps)\n    # Eval.\n    eval_results = est.evaluate(input_fn, steps=num_steps)\n    self.assertEqual(num_steps,\n                     eval_results[tf.compat.v1.GraphKeys.GLOBAL_STEP])\n    self.assertIn('loss', six.iterkeys(eval_results))\n    # Predict.\n    est.predict(input_fn)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/regression_head.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Regression head.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\n\nfrom tensorflow.python.framework import ops\nfrom tensorflow_estimator.python.estimator import model_fn\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.export import export_output\nfrom tensorflow_estimator.python.estimator.head import base_head\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v2\n\n\n@estimator_export('estimator.RegressionHead')\nclass RegressionHead(base_head.Head):\n  \"\"\"Creates a `Head` for regression using the `mean_squared_error` loss.\n\n  The loss is the weighted sum over all input dimensions. 
Namely, if the input\n  labels have shape `[batch_size, label_dimension]`, the loss is the weighted\n  sum over both `batch_size` and `label_dimension`.\n\n  The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`.\n  In many applications, the shape is `[batch_size, label_dimension]`.\n\n  The `labels` shape must match `logits`, namely\n  `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape\n  `[D0, D1, ... DN]` is also supported.\n\n  If `weight_column` is specified, weights must be of shape\n  `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or\n  `[D0, D1, ... DN, label_dimension]`.\n\n  Supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or\n  `(labels, logits, features, loss_reduction)` as arguments and returns\n  unreduced loss with shape `[D0, D1, ... DN, label_dimension]`.\n\n  Also supports custom `inverse_link_fn`, also known as 'mean function'.\n  `inverse_link_fn` is only used in `PREDICT` mode. It takes `logits` as\n  argument and returns predicted values. This function is the inverse of the\n  link function defined in\n  https://en.wikipedia.org/wiki/Generalized_linear_model#Link_function\n  Namely, for poisson regression, set `inverse_link_fn=tf.exp`.\n\n  Usage:\n\n  >>> head = tf.estimator.RegressionHead()\n  >>> logits = np.array(((45,), (41,),), dtype=np.float32)\n  >>> labels = np.array(((43,), (44,),), dtype=np.int32)\n  >>> features = {'x': np.array(((42,),), dtype=np.float32)}\n  >>> # expected_loss = weighted_loss / batch_size\n  >>> #               = (43-45)^2 + (44-41)^2 / 2 = 6.50\n  >>> loss = head.loss(labels, logits, features=features)\n  >>> print('{:.2f}'.format(loss.numpy()))\n  6.50\n  >>> eval_metrics = head.metrics()\n  >>> updated_metrics = head.update_metrics(\n  ...   eval_metrics, features, logits, labels)\n  >>> for k in sorted(updated_metrics):\n  ...  
print('{} : {:.2f}'.format(k, updated_metrics[k].result().numpy()))\n    average_loss : 6.50\n    label/mean : 43.50\n    prediction/mean : 43.00\n  >>> preds = head.predictions(logits)\n  >>> print(preds['predictions'])\n  tf.Tensor(\n    [[45.]\n     [41.]], shape=(2, 1), dtype=float32)\n\n  Usage with a canned estimator:\n\n  ```python\n  my_head = tf.estimator.RegressionHead()\n  my_estimator = tf.estimator.DNNEstimator(\n      head=my_head,\n      hidden_units=...,\n      feature_columns=...)\n  ```\n\n  It can also be used with a custom `model_fn`. Example:\n\n  ```python\n  def _my_model_fn(features, labels, mode):\n    my_head = tf.estimator.RegressionHead()\n    logits = tf_keras.Model(...)(features)\n\n    return my_head.create_estimator_spec(\n        features=features,\n        mode=mode,\n        labels=labels,\n        optimizer=tf_keras.optimizers.Adagrad(lr=0.1),\n        logits=logits)\n\n  my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)\n  ```\n\n  Args:\n    weight_column: A string or a `NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It is used to down weight or boost examples during training. It\n      will be multiplied by the loss of the example.\n    label_dimension: Number of regression labels per example. This is the size\n      of the last dimension of the labels `Tensor` (typically, this has shape\n      `[batch_size, label_dimension]`).\n    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Decides how to\n      reduce training loss over batch and label dimension. Defaults to\n      `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by\n      `batch_size * label_dimension`.\n    loss_fn: Optional loss function. Defaults to `mean_squared_error`.\n    inverse_link_fn: Optional inverse link function, also known as 'mean\n      function'. Defaults to identity.\n    name: name of the head. 
If provided, summary and metrics keys will be\n      suffixed by `\"/\" + name`. Also used as `name_scope` when creating ops.\n  \"\"\"\n\n  def __init__(self,\n               label_dimension=1,\n               weight_column=None,\n               loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,\n               loss_fn=None,\n               inverse_link_fn=None,\n               name=None):\n    if label_dimension < 1:\n      raise ValueError('Invalid label_dimension {}.'.format(label_dimension))\n    base_head.validate_loss_reduction(loss_reduction)\n    if loss_fn:\n      base_head.validate_loss_fn_args(loss_fn)\n    self._logits_dimension = label_dimension\n    self._weight_column = weight_column\n    self._loss_reduction = loss_reduction\n    self._loss_fn = loss_fn\n    self._inverse_link_fn = inverse_link_fn\n    self._name = name\n    # Metric keys.\n    keys = metric_keys.MetricKeys\n    self._loss_mean_key = self._summary_key(keys.LOSS_MEAN)\n    self._prediction_mean_key = self._summary_key(keys.PREDICTION_MEAN)\n    self._label_mean_key = self._summary_key(keys.LABEL_MEAN)\n    self._loss_regularization_key = self._summary_key(keys.LOSS_REGULARIZATION)\n\n  @property\n  def name(self):\n    \"\"\"See `base_head.Head` for details.\"\"\"\n    return self._name\n\n  @property\n  def logits_dimension(self):\n    \"\"\"See `base_head.Head` for details.\"\"\"\n    return self._logits_dimension\n\n  @property\n  def loss_reduction(self):\n    \"\"\"See `base_head.Head` for details.\"\"\"\n    return self._loss_reduction\n\n  def _processed_labels(self, logits, labels):\n    labels = base_head.check_dense_labels_match_logits_and_reshape(\n        labels=labels,\n        logits=logits,\n        expected_labels_dimension=self._logits_dimension)\n    labels = tf.cast(labels, dtype=tf.dtypes.float32)\n    return labels\n\n  def _unweighted_loss_and_weights(self, logits, labels, features):\n    \"\"\"Computes unweighted loss and weights.\"\"\"\n    if 
self._loss_fn:\n      unweighted_loss = base_head.call_loss_fn(\n          loss_fn=self._loss_fn,\n          labels=labels,\n          logits=logits,\n          features=features,\n          expected_loss_dim=self._logits_dimension)\n    else:\n      unweighted_loss = tf.compat.v1.losses.mean_squared_error(\n          labels=labels,\n          predictions=logits,\n          reduction=tf.compat.v1.losses.Reduction.NONE)\n    weights = base_head.get_weights_and_check_match_logits(\n        features=features,\n        weight_column=self._weight_column,\n        logits=logits,\n        allow_per_logit_weights=True)\n    return unweighted_loss, weights\n\n  def loss(self,\n           labels,\n           logits,\n           features=None,\n           mode=None,\n           regularization_losses=None):\n    \"\"\"Return predictions based on keys. See `base_head.Head` for details.\"\"\"\n    del mode  # Unused for this head.\n    with ops.name_scope(\n        'losses', values=(logits, labels, regularization_losses, features)):\n      logits = base_head.check_logits_final_dim(logits, self._logits_dimension)\n      labels = self._processed_labels(logits, labels)\n      unweighted_loss, weights = self._unweighted_loss_and_weights(\n          logits, labels, features)\n      training_loss = tf_keras_v2.__internal__.losses.compute_weighted_loss(\n          unweighted_loss,\n          sample_weight=weights,\n          reduction=self._loss_reduction)\n      regularization_loss = tf.math.add_n(\n          regularization_losses) if regularization_losses is not None else None\n      regularized_training_loss = (\n          training_loss + regularization_loss\n          if regularization_loss is not None else training_loss)\n    return regularized_training_loss\n\n  def predictions(self, logits):\n    \"\"\"Return predictions based on keys.\n\n    See `base_head.Head` for details.\n\n    Args:\n      logits: logits `Tensor` with shape `[D0, D1, ... 
DN, logits_dimension]`.\n        For many applications, the shape is `[batch_size, logits_dimension]`.\n\n    Returns:\n      A dict of predictions.\n    \"\"\"\n    logits = base_head.check_logits_final_dim(logits, self._logits_dimension)\n    pred_keys = prediction_keys.PredictionKeys\n    with ops.name_scope('predictions', values=(logits,)):\n      if self._inverse_link_fn:\n        predicted_value = self._inverse_link_fn(logits)\n        predictions = {\n            pred_keys.PREDICTIONS: predicted_value,\n            pred_keys.LOGITS: logits,\n        }\n      else:\n        predicted_value = logits\n        predictions = {pred_keys.PREDICTIONS: predicted_value}\n    return predictions\n\n  def metrics(self, regularization_losses=None):\n    \"\"\"Creates metrics. See `base_head.Head` for details.\"\"\"\n    with ops.name_scope('metrics', values=(regularization_losses,)):\n      keys = metric_keys.MetricKeys\n      eval_metrics = {}\n      eval_metrics[self._loss_mean_key] = tf_keras.metrics.Mean(\n          name=keys.LOSS_MEAN)\n      eval_metrics[self._prediction_mean_key] = tf_keras.metrics.Mean(\n          name=keys.PREDICTION_MEAN)\n      eval_metrics[self._label_mean_key] = tf_keras.metrics.Mean(\n          name=keys.LABEL_MEAN)\n\n      if regularization_losses is not None:\n        eval_metrics[self._loss_regularization_key] = tf_keras.metrics.Mean(\n            name=keys.LOSS_REGULARIZATION)\n    return eval_metrics\n\n  def update_metrics(self,\n                     eval_metrics,\n                     features,\n                     logits,\n                     labels,\n                     regularization_losses=None):\n    \"\"\"Updates eval metrics. 
See `base_head.Head` for details.\"\"\"\n    # Compute predictions.\n    predictions = self.predictions(logits)\n    predicted_value = predictions[prediction_keys.PredictionKeys.PREDICTIONS]\n    logits = base_head.check_logits_final_dim(logits, self.logits_dimension)\n    label_ids = self._processed_labels(logits, labels)\n    unweighted_loss, weights = self._unweighted_loss_and_weights(\n        logits, label_ids, features)\n\n    # Update metrics.\n    eval_metrics[self._loss_mean_key].update_state(\n        values=unweighted_loss, sample_weight=weights)\n    eval_metrics[self._label_mean_key].update_state(\n        values=labels, sample_weight=weights)\n    base_head.update_metric_with_broadcast_weights(\n        eval_metrics[self._prediction_mean_key], predicted_value, weights)\n    if regularization_losses is not None:\n      regularization_loss = tf.math.add_n(regularization_losses)\n      eval_metrics[self._loss_regularization_key].update_state(\n          values=regularization_loss)\n    return eval_metrics\n\n  def _create_tpu_estimator_spec(self,\n                                 features,\n                                 mode,\n                                 logits,\n                                 labels=None,\n                                 optimizer=None,\n                                 trainable_variables=None,\n                                 train_op_fn=None,\n                                 update_ops=None,\n                                 regularization_losses=None):\n    \"\"\"Returns an `EstimatorSpec`.\n\n    Args:\n      features: Input `dict` mapping string feature names to `Tensor` or\n        `SparseTensor` objects containing the values for that feature in a\n        minibatch. Often to be used to fetch example-weight tensor.\n      mode: Estimator's `ModeKeys`.\n      logits: logits `Tensor` with shape `[D0, D1, ... 
DN, logits_dimension]`.\n        For many applications, the shape is `[batch_size, logits_dimension]`.\n      labels: Labels `Tensor` with shape matching `logits`, namely `[D0, D1, ...\n        DN, logits_dimension]`. When `logits_dimension=1`, shape `[D0, D1, ...\n        DN]` is also supported. `labels` is a required argument when `mode`\n        equals `TRAIN` or `EVAL`.\n      optimizer: An `tf_keras.optimizers.Optimizer` instance to optimize the\n        loss in TRAIN mode. Namely, sets `train_op = optimizer.get_updates(loss,\n        trainable_variables)`, which updates variables to minimize `loss`.\n      trainable_variables: A list or tuple of `Variable` objects to update to\n        minimize `loss`. In Tensorflow 1.x, by default these are the list of\n        variables collected in the graph under the key\n        `GraphKeys.TRAINABLE_VARIABLES`. As Tensorflow 2.x doesn't have\n        collections and GraphKeys, trainable_variables need to be passed\n        explicitly here.\n      train_op_fn: Function that takes a scalar loss `Tensor` and returns\n        `train_op`. Used if `optimizer` is `None`.\n      update_ops: A list or tuple of update ops to be run at training time. For\n        example, layers such as BatchNormalization create mean and variance\n        update ops that need to be run at training time. In Tensorflow 1.x,\n        these are thrown into an UPDATE_OPS collection. As Tensorflow 2.x\n        doesn't have collections, update_ops need to be passed explicitly here.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses. 
These losses are\n        usually expressed as a batch average, so for best results users need to\n        set `loss_reduction=SUM_OVER_BATCH_SIZE` when creating the head to avoid\n        scaling errors.\n\n    Returns:\n      A `model_fn._TPUEstimatorSpec` instance.\n\n    Raises:\n      ValueError: If both `train_op_fn` and `optimizer` are `None` in TRAIN\n        mode, or if both are set.\n    \"\"\"\n    with ops.name_scope(self._name, 'head'):\n      # Predict.\n      predictions = self.predictions(logits)\n      if mode == ModeKeys.PREDICT:\n        keys = prediction_keys.PredictionKeys\n        regression_output = export_output.RegressionOutput(\n            value=predictions[keys.PREDICTIONS])\n        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n            mode=ModeKeys.PREDICT,\n            predictions=predictions,\n            export_outputs={\n                base_head.DEFAULT_SERVING_KEY: regression_output,\n                base_head.REGRESS_SERVING_KEY: regression_output,\n                base_head.PREDICT_SERVING_KEY: export_output.PredictOutput(\n                    predictions)\n            })\n      regularized_training_loss = self.loss(\n          logits=logits,\n          labels=labels,\n          features=features,\n          mode=mode,\n          regularization_losses=regularization_losses)\n      # Eval.\n      if mode == ModeKeys.EVAL:\n        eval_metrics = self.metrics(regularization_losses=regularization_losses)\n        return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n            mode=ModeKeys.EVAL,\n            predictions=predictions,\n            loss=regularized_training_loss,\n            eval_metrics=base_head.create_eval_metrics_tuple(\n                self.update_metrics, {\n                    'eval_metrics': eval_metrics,\n                    'features': features,\n                    'logits': logits,\n                    'labels': labels,\n                    
'regularization_losses': regularization_losses\n                }))\n      # Train.\n      train_op = base_head.create_estimator_spec_train_op(\n          head_name=self._name,\n          optimizer=optimizer,\n          train_op_fn=train_op_fn,\n          update_ops=update_ops,\n          trainable_variables=trainable_variables,\n          regularized_training_loss=regularized_training_loss,\n          loss_reduction=self._loss_reduction)\n    # Create summary.\n    base_head.create_estimator_spec_summary(\n        regularized_training_loss=regularized_training_loss,\n        regularization_losses=regularization_losses,\n        summary_key_fn=self._summary_key)\n    return model_fn._TPUEstimatorSpec(  # pylint: disable=protected-access\n        mode=ModeKeys.TRAIN,\n        predictions=predictions,\n        loss=regularized_training_loss,\n        train_op=train_op)\n\n\n@estimator_export('estimator.PoissonRegressionHead')\nclass PoissonRegressionHead(RegressionHead):\n  \"\"\"Creates a `Head` for poisson regression using `tf.nn.log_poisson_loss`.\n\n  The loss is the weighted sum over all input dimensions. Namely, if the input\n  labels have shape `[batch_size, label_dimension]`, the loss is the weighted\n  sum over both `batch_size` and `label_dimension`.\n\n  The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`.\n  In many applications, the shape is `[batch_size, label_dimension]`.\n\n  The `labels` shape must match `logits`, namely\n  `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape\n  `[D0, D1, ... DN]` is also supported.\n\n  If `weight_column` is specified, weights must be of shape\n  `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or\n  `[D0, D1, ... DN, label_dimension]`.\n\n  This is implemented as a generalized linear model, see\n  https://en.wikipedia.org/wiki/Generalized_linear_model.\n\n  The head can be used with a canned estimator. 
Example:\n\n  ```python\n  my_head = tf.estimator.PoissonRegressionHead()\n  my_estimator = tf.estimator.DNNEstimator(\n      head=my_head,\n      hidden_units=...,\n      feature_columns=...)\n  ```\n\n  It can also be used with a custom `model_fn`. Example:\n\n  ```python\n  def _my_model_fn(features, labels, mode):\n    my_head = tf.estimator.PoissonRegressionHead()\n    logits = tf_keras.Model(...)(features)\n\n    return my_head.create_estimator_spec(\n        features=features,\n        mode=mode,\n        labels=labels,\n        optimizer=tf_keras.optimizers.Adagrad(lr=0.1),\n        logits=logits)\n\n  my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)\n  ```\n\n  Args:\n    weight_column: A string or a `NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It is used to down weight or boost examples during training. It\n      will be multiplied by the loss of the example.\n    label_dimension: Number of regression labels per example. This is the size\n      of the last dimension of the labels `Tensor` (typically, this has shape\n      `[batch_size, label_dimension]`).\n    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Decides how to\n      reduce training loss over batch and label dimension. Defaults to\n      `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by `batch\n      size * label_dimension`.\n    compute_full_loss: Whether to include the constant `log(z!)` term in\n      computing the poisson loss. See `tf.nn.log_poisson_loss` for the full\n      documentation.\n    name: name of the head. If provided, summary and metrics keys will be\n      suffixed by `\"/\" + name`. 
Also used as `name_scope` when creating ops.\n  \"\"\"\n\n  def __init__(self,\n               label_dimension=1,\n               weight_column=None,\n               loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,\n               compute_full_loss=True,\n               name=None):\n    self._compute_full_loss = compute_full_loss\n    super(PoissonRegressionHead, self).__init__(\n        label_dimension=label_dimension,\n        weight_column=weight_column,\n        loss_reduction=loss_reduction,\n        loss_fn=self._poisson_loss,\n        inverse_link_fn=tf.math.exp,\n        name=name)\n\n  def _poisson_loss(self, labels, logits):\n    return tf.nn.log_poisson_loss(\n        targets=labels,\n        log_input=logits,\n        compute_full_loss=self._compute_full_loss)\n\n\n@estimator_export('estimator.LogisticRegressionHead')\nclass LogisticRegressionHead(RegressionHead):\n  \"\"\"Creates a `Head` for logistic regression.\n\n  Uses `sigmoid_cross_entropy_with_logits` loss, which is the same as\n  `BinaryClassHead`. The differences compared to `BinaryClassHead` are:\n\n  * Does not support `label_vocabulary`. Instead, labels must be float in the\n    range [0, 1].\n  * Does not calculate some metrics that do not make sense, such as AUC.\n  * In `PREDICT` mode, only returns logits and predictions\n    (`=tf.sigmoid(logits)`), whereas `BinaryClassHead` also returns\n    probabilities, classes, and class_ids.\n  * Export output defaults to `RegressionOutput`, whereas `BinaryClassHead`\n    defaults to `PredictOutput`.\n\n  The head expects `logits` with shape `[D0, D1, ... DN, 1]`.\n  In many applications, the shape is `[batch_size, 1]`.\n\n  The `labels` shape must match `logits`, namely\n  `[D0, D1, ... DN]` or `[D0, D1, ... DN, 1]`.\n\n  If `weight_column` is specified, weights must be of shape\n  `[D0, D1, ... DN]` or `[D0, D1, ... 
DN, 1]`.\n\n  This is implemented as a generalized linear model, see\n  https://en.wikipedia.org/wiki/Generalized_linear_model.\n\n  The head can be used with a canned estimator. Example:\n\n  ```python\n  my_head = tf.estimator.LogisticRegressionHead()\n  my_estimator = tf.estimator.DNNEstimator(\n      head=my_head,\n      hidden_units=...,\n      feature_columns=...)\n  ```\n\n  It can also be used with a custom `model_fn`. Example:\n\n  ```python\n  def _my_model_fn(features, labels, mode):\n    my_head = tf.estimator.LogisticRegressionHead()\n    logits = tf_keras.Model(...)(features)\n\n    return my_head.create_estimator_spec(\n        features=features,\n        mode=mode,\n        labels=labels,\n        optimizer=tf_keras.optimizers.Adagrad(lr=0.1),\n        logits=logits)\n\n  my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)\n  ```\n\n  Args:\n    weight_column: A string or a `NumericColumn` created by\n      `tf.feature_column.numeric_column` defining feature column representing\n      weights. It is used to down weight or boost examples during training. It\n      will be multiplied by the loss of the example.\n    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Decides how to\n      reduce training loss over batch and label dimension. Defaults to\n      `SUM_OVER_BATCH_SIZE`, namely weighted sum of losses divided by `batch\n      size * label_dimension`.\n    name: name of the head. If provided, summary and metrics keys will be\n      suffixed by `\"/\" + name`. 
Also used as `name_scope` when creating ops.\n  \"\"\"\n\n  def _logistic_loss(self, labels, logits):\n    labels = base_head.check_label_range(\n        labels, n_classes=2, message='Labels must be in range [0, 1]')\n    return tf.compat.v1.nn.sigmoid_cross_entropy_with_logits(\n        labels=labels, logits=logits)\n\n  def __init__(self,\n               weight_column=None,\n               loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE,\n               name=None):\n    super(LogisticRegressionHead, self).__init__(\n        label_dimension=1,\n        weight_column=weight_column,\n        loss_reduction=loss_reduction,\n        loss_fn=self._logistic_loss,\n        inverse_link_fn=tf.math.sigmoid,\n        name=name)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/regression_head_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for regression_head.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.head import head_utils as test_lib\nfrom tensorflow_estimator.python.estimator.head import regression_head as head_lib\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass RegressionHead(tf.test.TestCase):\n\n  def test_invalid_label_dimension(self):\n    with self.assertRaisesRegexp(ValueError, r'Invalid label_dimension'):\n      head_lib.RegressionHead(label_dimension=-1)\n    with self.assertRaisesRegexp(ValueError, r'Invalid label_dimension'):\n      head_lib.RegressionHead(label_dimension=0)\n\n  def test_invalid_loss_reduction(self):\n    with self.assertRaisesRegexp(\n        ValueError, r'Invalid loss_reduction: 
invalid_loss_reduction'):\n      head_lib.RegressionHead(loss_reduction='invalid_loss_reduction')\n    with self.assertRaisesRegexp(ValueError, r'Invalid loss_reduction: none'):\n      head_lib.RegressionHead(loss_reduction=tf.losses.Reduction.NONE)\n\n  def test_loss_fn_arg_labels_missing(self):\n\n    def _loss_fn(logits):\n      del logits  # Unused\n\n    with self.assertRaisesRegexp(\n        ValueError, r'loss_fn must contain argument: labels\\. '\n        r'Given arguments: \\(\\'logits\\',\\)'):\n      head_lib.RegressionHead(loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_logits_missing(self):\n\n    def _loss_fn(labels):\n      del labels  # unused\n\n    with self.assertRaisesRegexp(\n        ValueError, r'loss_fn must contain argument: logits\\. '\n        r'Given arguments: \\(\\'labels\\',\\)'):\n      head_lib.RegressionHead(loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_features_ok(self):\n\n    def _loss_fn(labels, logits, features):\n      del labels, logits, features  # Unused\n      head_lib.RegressionHead(loss_fn=_loss_fn)\n\n  def test_loss_fn_arg_invalid(self):\n\n    def _loss_fn(labels, logits, name=None):\n      del labels, logits, name  # Unused\n\n    with self.assertRaisesRegexp(ValueError,\n                                 r'loss_fn has unexpected args: \\[\\'name\\'\\]'):\n      head_lib.RegressionHead(loss_fn=_loss_fn)\n\n  def test_invalid_logits(self):\n    \"\"\"Label dimension is 3, logits shape [1, 2, 1].\"\"\"\n    head = head_lib.RegressionHead(label_dimension=3)\n    self.assertEqual(3, head.logits_dimension)\n    logits_1d = np.array(((45.,), (41.,),))\n\n    # Static shape.\n    with self.assertRaisesRegexp(ValueError, 'logits shape'):\n      pred = head.predictions(logits_1d)\n      self.evaluate(pred[prediction_keys.PredictionKeys.PREDICTIONS])\n    if tf.executing_eagerly():\n      return\n\n    # Dynamic shape only works in Graph mode.\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    spec = 
head.create_estimator_spec(\n        features={'x': np.array(((42.,),))},\n        mode=ModeKeys.PREDICT,\n        logits=logits_placeholder,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    with self.cached_session():\n      with self.assertRaisesRegexp(tf.errors.OpError, 'logits shape'):\n        spec.predictions[prediction_keys.PredictionKeys.PREDICTIONS].eval(\n            {logits_placeholder: logits_1d})\n\n  def test_incompatible_labels_eval(self):\n    head = head_lib.RegressionHead(label_dimension=3)\n    self.assertEqual(3, head.logits_dimension)\n    values_3d = np.array(((45., 46., 47.), (41., 42., 43.),))\n    values_1d = np.array(((43.,), (44.,),))\n\n    # Static shape.\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):\n        head.loss(\n            logits=values_3d,\n            labels=values_1d,\n            features={'x': values_1d},\n            mode=ModeKeys.EVAL)\n      return\n\n    # Dynamic shape only works in Graph mode.\n    with self.assertRaisesRegexp(ValueError, 'logits shape'):\n      head.create_estimator_spec(\n          features={'x': values_3d},\n          labels=values_3d,\n          mode=ModeKeys.EVAL,\n          logits=values_1d,\n          train_op_fn=None,\n          trainable_variables=[\n              tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n          ])\n    labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    spec = head.create_estimator_spec(\n        features={'x': values_1d},\n        mode=ModeKeys.EVAL,\n        logits=logits_placeholder,\n        labels=labels_placeholder,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    with self.cached_session():\n      with self.assertRaisesRegexp(tf.errors.OpError, 'logits shape'):\n        spec.loss.eval({\n            labels_placeholder: 
values_3d,\n            logits_placeholder: values_1d\n        })\n    regularized_training_loss = head.loss(\n        logits=logits_placeholder,\n        labels=labels_placeholder,\n        features={'x': values_1d},\n        mode=ModeKeys.EVAL)\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[expected_labels_shape: \\] \\[2 3\\] \\[labels_shape: \\] \\[2 1\\]'):\n        regularized_training_loss.eval({\n            labels_placeholder: values_1d,\n            logits_placeholder: values_3d\n        })\n\n  def test_incompatible_labels_train(self):\n    head = head_lib.RegressionHead(label_dimension=3)\n    self.assertEqual(3, head.logits_dimension)\n    values_3d = np.array(((45., 46., 47.), (41., 42., 43.),))  # shape [2, 3]\n    values_1d = np.array(((43.,), (44.,),))  # shape [2, 1]\n\n    # Static shape.\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, 'Mismatched label shape'):\n        head.loss(\n            logits=values_3d,\n            labels=values_1d,\n            features={'x': values_1d},\n            mode=ModeKeys.TRAIN)\n      return\n\n    # Dynamic shape only works in Graph mode.\n    with self.assertRaisesRegexp(ValueError, 'logits shape'):\n      head.create_estimator_spec(\n          features={'x': values_3d},\n          mode=ModeKeys.TRAIN,\n          logits=values_1d,\n          labels=values_3d,\n          train_op_fn=lambda x: x,\n          trainable_variables=[\n              tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n          ])\n    labels_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    logits_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    spec = head.create_estimator_spec(\n        features={'x': values_1d},\n        mode=ModeKeys.TRAIN,\n        logits=logits_placeholder,\n        labels=labels_placeholder,\n        train_op_fn=lambda x: x,\n        
trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    with self.cached_session():\n      with self.assertRaisesRegexp(tf.errors.OpError, 'logits shape'):\n        spec.loss.eval({\n            labels_placeholder: values_3d,\n            logits_placeholder: values_1d\n        })\n    regularized_training_loss = head.loss(\n        logits=logits_placeholder,\n        labels=labels_placeholder,\n        features={'x': values_1d},\n        mode=ModeKeys.TRAIN)\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[expected_labels_shape: \\] \\[2 3\\] \\[labels_shape: \\] \\[2 1\\]'):\n        regularized_training_loss.eval({\n            labels_placeholder: values_1d,\n            logits_placeholder: values_3d\n        })\n\n  def test_predict(self):\n    head = head_lib.RegressionHead()\n    self.assertEqual(1, head.logits_dimension)\n\n    logits = np.array(((45,), (41,),), dtype=np.int32)\n    preds = head.predictions(logits)\n\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    self.assertItemsEqual((prediction_key,), preds.keys())\n    predictions = preds[prediction_key]\n    self.assertEqual(tf.dtypes.float32, predictions.dtype)\n    self.assertAllClose(logits, self.evaluate(predictions))\n\n    if tf.executing_eagerly():\n      return\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42.,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=np.array(((45,), (41,),), dtype=np.int32),\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    self.assertIsNone(spec.loss)\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNone(spec.train_op)\n    default_serving_key = (tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)\n    self.assertItemsEqual((default_serving_key, 'predict', 'regression'),\n                          
spec.export_outputs.keys())\n    test_lib._assert_no_hooks(self, spec)\n    # Assert predictions.\n    with self.cached_session():\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertAllClose(logits,\n                          spec.export_outputs[default_serving_key].value.eval())\n      self.assertAllClose(logits,\n                          spec.export_outputs['regression'].value.eval())\n      self.assertAllClose(\n          logits, spec.export_outputs['predict'].outputs['predictions'].eval())\n\n  def test_predict_with_inverse_link_fn(self):\n\n    def _inverse_link_fn(logits):\n      return logits - 10.\n\n    head = head_lib.RegressionHead(inverse_link_fn=_inverse_link_fn)\n\n    logits = np.array(((45,), (41,),), dtype=np.int32)\n    preds = head.predictions(logits)\n\n    keys = prediction_keys.PredictionKeys\n    self.assertItemsEqual((keys.PREDICTIONS, keys.LOGITS), preds.keys())\n    self.assertEqual(tf.dtypes.float32, preds[keys.PREDICTIONS].dtype)\n    self.assertEqual(tf.dtypes.float32, preds[keys.LOGITS].dtype)\n\n    expected_predictions = np.array(((35,), (31,),), dtype=np.int32)\n    self.assertAllClose(expected_predictions,\n                        self.evaluate(preds[keys.PREDICTIONS]))\n    self.assertAllClose(logits, self.evaluate(preds[keys.LOGITS]))\n\n    if tf.executing_eagerly():\n      return\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42.,),), dtype=np.int32)},\n        mode=ModeKeys.PREDICT,\n        logits=logits,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    # Assert spec contains expected tensors.\n    default_serving_key = (tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY)\n    self.assertItemsEqual((default_serving_key, 'predict', 'regression'),\n                          spec.export_outputs.keys())\n    # Assert predictions.\n    with self.cached_session():\n      test_lib._initialize_variables(self, 
spec.scaffold)\n      self.assertAllClose(expected_predictions,\n                          spec.export_outputs[default_serving_key].value.eval())\n      self.assertAllClose(expected_predictions,\n                          spec.export_outputs['regression'].value.eval())\n      self.assertAllClose(\n          expected_predictions,\n          spec.export_outputs['predict'].outputs['predictions'].eval())\n      self.assertAllClose(\n          logits, spec.export_outputs['predict'].outputs['logits'].eval())\n\n  def test_eval_create_loss(self):\n    head = head_lib.RegressionHead()\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43,), (44,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n\n    regularized_training_loss = head.loss(\n        logits=logits, labels=labels, features=features)\n    self.assertAllClose(6.5, self.evaluate(regularized_training_loss))\n\n  def test_eval_create_loss_loss_fn(self):\n    \"\"\"Tests head.loss for eval mode and custom loss_fn.\"\"\"\n    loss = np.array([[0., 1.], [2., 3.]], dtype=np.float32)\n    batch_size = 4.\n    logits_input = np.array([[-1., 1.], [-2., 2.]], dtype=np.float32)\n    labels_input = np.array([[1., 0.], [2., -1.]], dtype=np.float32)\n\n    def _loss_fn(labels, logits):\n      check_labels = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(labels, labels_input)), data=[labels])\n      check_logits = tf.debugging.Assert(\n          tf.reduce_all(tf.math.equal(logits, logits_input)), data=[logits])\n      with tf.control_dependencies([check_labels, check_logits]):\n        return tf.constant(loss)\n\n    head = head_lib.RegressionHead(label_dimension=2, loss_fn=_loss_fn)\n    regularized_training_loss = head.loss(\n        logits=logits_input,\n        labels=labels_input,\n        features={'x': np.array(((42,),), dtype=np.int32)})\n    self.assertAllClose(\n        np.sum(loss) / batch_size, self.evaluate(regularized_training_loss))\n\n  
def test_eval_create_loss_loss_fn_wrong_shape(self):\n    \"\"\"Tests custom loss_fn that returns Tensor of unexpected shape.\"\"\"\n    loss = np.array([[1.], [2.]], dtype=np.float32)\n\n    def _loss_fn(labels, logits):\n      del labels, logits  # Unused\n      return tf.constant(loss)\n\n    head = head_lib.RegressionHead(label_dimension=2, loss_fn=_loss_fn)\n\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n    logits = np.array([[-1., 1.], [-2., 2.]], dtype=np.float32)\n    labels = np.array([[1., 0.], [2., -1.]], dtype=np.float32)\n\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, 'loss_shape'):\n        head.loss(logits=logits, labels=labels, features=features)\n    else:\n      regularized_training_loss = head.loss(\n          logits=logits, labels=labels, features=features)\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[loss_fn must return Tensor of shape \\[D0, D1, ... DN, 2\\]\\. \\] '\n          r'\\[logits_shape: \\] \\[2 2\\] \\[loss_shape: \\] \\[2 1\\]'):\n        self.evaluate(regularized_training_loss)\n\n  def test_eval_labels_none(self):\n    \"\"\"Tests that error is raised when labels is None.\"\"\"\n    head = head_lib.RegressionHead()\n\n    with self.assertRaisesRegexp(\n        ValueError, r'You must provide a labels Tensor\\. 
Given: None\\.'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.EVAL,\n          logits=np.array(((45,), (41,),), dtype=np.float32),\n          labels=None,\n          trainable_variables=[\n              tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n          ])\n\n  def test_eval(self):\n    head = head_lib.RegressionHead()\n    self.assertEqual(1, head.logits_dimension)\n\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43,), (44,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    preds = head.predictions(logits)\n    self.assertItemsEqual((prediction_key,), preds.keys())\n    self.assertEqual(tf.dtypes.float32, preds[prediction_key].dtype)\n    self.assertAllClose(logits, self.evaluate(preds[prediction_key]))\n\n    # weighted_loss = (43-45)^2 + (44-41)^2 = 13.\n    # loss = weighted_loss / batch_size = (4+9) / 2 = 6.5\n    expected_loss = 6.5\n    # loss_mean = loss/sum(weights) = 13/2 = 6.5\n    expected_loss_mean = 6.5\n    if tf.executing_eagerly():\n      eval_metrics = head.metrics()\n      update_metrics = head.update_metrics(eval_metrics, features, logits,\n                                           labels)\n      self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN,\n                             metric_keys.MetricKeys.PREDICTION_MEAN,\n                             metric_keys.MetricKeys.LABEL_MEAN),\n                            update_metrics.keys())\n      self.assertAllClose(\n          expected_loss_mean,\n          update_metrics[metric_keys.MetricKeys.LOSS_MEAN].result())\n      loss = head.loss(labels, logits, features=features, mode=ModeKeys.EVAL)\n      self.assertIsNotNone(loss)\n      self.assertAllClose(expected_loss, loss)\n    else:\n      # Create estimator spec.\n      spec = head.create_estimator_spec(\n          
features=features,\n          mode=ModeKeys.EVAL,\n          logits=logits,\n          labels=labels,\n          trainable_variables=[\n              tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n          ])\n      # Assert spec contains expected tensors.\n      self.assertEqual(tf.dtypes.float32, spec.loss.dtype)\n      self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN,\n                             metric_keys.MetricKeys.PREDICTION_MEAN,\n                             metric_keys.MetricKeys.LABEL_MEAN),\n                            spec.eval_metric_ops.keys())\n      self.assertIsNone(spec.train_op)\n      self.assertIsNone(spec.export_outputs)\n      test_lib._assert_no_hooks(self, spec)\n      # Assert predictions, loss, and metrics.\n      with self.cached_session() as sess:\n        test_lib._initialize_variables(self, spec.scaffold)\n        self.assertIsNone(spec.scaffold.summary_op)\n        loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[\n            metric_keys.MetricKeys.LOSS_MEAN]\n        loss, _ = sess.run((spec.loss, loss_mean_update_op))\n        self.assertAllClose(6.5, loss)\n        # Check results of value ops (in `loss_mean`).\n        self.assertAllClose(expected_loss_mean, loss_mean_value_op.eval())\n\n  def test_eval_metric_ops_with_head_name_for_regression(self):\n    head = head_lib.RegressionHead(name='some_regression_head')\n    logits = np.array(((1,), (9,)), dtype=np.float32)\n    labels = np.array(((1,), (1,)), dtype=np.int64)\n    features = {'x': np.array(((42,),), dtype=np.int32)}\n\n    expected_metric_keys = [\n        '{}/some_regression_head'.format(metric_keys.MetricKeys.LOSS_MEAN),\n        '{}/some_regression_head'.format(\n            metric_keys.MetricKeys.PREDICTION_MEAN),\n        '{}/some_regression_head'.format(metric_keys.MetricKeys.LABEL_MEAN),\n    ]\n    eval_metrics = head.metrics()\n    updated_metrics = head.update_metrics(eval_metrics, features, logits,\n                               
           labels)\n    self.assertItemsEqual(expected_metric_keys, updated_metrics.keys())\n\n  def test_eval_with_regularization_losses(self):\n    head = head_lib.RegressionHead()\n    self.assertEqual(1, head.logits_dimension)\n\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43,), (44,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    regularization_losses = [1.5, 0.5]\n    expected_regularization_loss = 2.\n    # unregularized_loss = ((43-45)^2 + (44-41)^2) / batch_size\n    #                    = (4 + 9) / 2 = 6.5\n    expected_unregularized_loss = 6.5\n    expected_regularized_loss = (\n        expected_unregularized_loss + expected_regularization_loss)\n\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: expected_unregularized_loss,\n        keys.LOSS_REGULARIZATION: expected_regularization_loss,\n        keys.PREDICTION_MEAN: (45 + 41) / 2.0,\n        keys.LABEL_MEAN: (43 + 44) / 2.0,\n    }\n    # Test eval metrics in eager mode\n    if tf.executing_eagerly():\n      eval_metrics = head.metrics(regularization_losses=regularization_losses)\n      updated_metrics = head.update_metrics(\n          eval_metrics,\n          features,\n          logits,\n          labels,\n          regularization_losses=regularization_losses)\n      # Assert metrics.\n      self.assertAllClose(\n          expected_metrics,\n          {k: updated_metrics[k].result() for k in updated_metrics})\n    else:\n      # Create estimator spec.\n      spec = head.create_estimator_spec(\n          features=features,\n          mode=ModeKeys.EVAL,\n          logits=logits,\n          labels=labels,\n          regularization_losses=regularization_losses,\n          trainable_variables=[\n              tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n          ])\n      # Assert predictions, loss, and metrics.\n      with self.cached_session() as sess:\n        
test_lib._initialize_variables(self, spec.scaffold)\n        self.assertIsNone(spec.scaffold.summary_op)\n        value_ops = {\n            k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops\n        }\n        update_ops = {\n            k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops\n        }\n        prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n        predictions, loss, _ = sess.run(\n            (spec.predictions[prediction_key], spec.loss, update_ops))\n        self.assertAllClose(logits, predictions)\n        self.assertAllClose(expected_regularized_loss, loss)\n        # Check results of value ops (in `metrics`).\n        self.assertAllClose(expected_metrics,\n                            {k: value_ops[k].eval() for k in value_ops})\n\n  def test_train_create_loss(self):\n    head = head_lib.RegressionHead()\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43,), (44,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # training_loss = (1 * 4 + 1 * 9) / 2 = 6.5\n    expected_training_loss = 6.5\n    # Create loss.\n    training_loss = head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n    self.assertAllClose(expected_training_loss, self.evaluate(training_loss))\n\n  def test_train_create_loss_loss_reduction(self):\n    \"\"\"Tests create_loss with loss_reduction.\"\"\"\n    head = head_lib.RegressionHead(loss_reduction=tf.losses.Reduction.SUM)\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43,), (44,),), dtype=np.int32)\n    features = {'x': np.array(((42,),), dtype=np.float32)}\n    # training_loss = (1 * 4 + 1 * 9)\n    expected_training_loss = 13.\n    # Create loss.\n    training_loss = head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n    self.assertAllClose(expected_training_loss, self.evaluate(training_loss))\n\n  def 
test_train_labels_none(self):\n    \"\"\"Tests that error is raised when labels is None.\"\"\"\n    head = head_lib.RegressionHead()\n\n    def _no_op_train_fn(loss):\n      del loss\n      return tf.no_op()\n\n    with self.assertRaisesRegexp(\n        ValueError, r'You must provide a labels Tensor\\. Given: None\\.'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array((\n              (45,),\n              (41,),\n          ), dtype=np.float32),\n          labels=None,\n          train_op_fn=_no_op_train_fn,\n          trainable_variables=[\n              tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n          ])\n\n  def test_train(self):\n    head = head_lib.RegressionHead()\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43.,), (44.,),), dtype=np.float64)\n    expected_train_result = b'my_train_op'\n    features = {'x': np.array(((42.,),), dtype=np.float32)}\n    # loss = ((43-45)^2 + (44-41)^2) / 2 = (4 + 9) / 2 = 6.5\n    expected_loss = 6.5\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    preds = head.predictions(logits)\n    loss = head.loss(labels, logits, features=features)\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    self.assertItemsEqual((prediction_key,), preds.keys())\n    self.assertEqual(tf.dtypes.float32, preds[prediction_key].dtype)\n    self.assertEqual(tf.dtypes.float32, loss.dtype)\n    self.assertAllClose(logits, self.evaluate(preds[prediction_key]))\n    self.assertAllClose(expected_loss, self.evaluate(loss))\n    if 
tf.executing_eagerly():\n      return\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    # Assert spec contains expected tensors.\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    test_lib._assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      predictions, loss, train_result, summary_str = sess.run(\n          (spec.predictions[prediction_key], spec.loss, spec.train_op,\n           spec.scaffold.summary_op))\n      self.assertAllClose(logits, predictions)\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      test_lib._assert_simple_summaries(self, {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n      }, summary_str)\n\n  def test_train_with_regularization_losses(self):\n    head = head_lib.RegressionHead()\n    self.assertEqual(1, head.logits_dimension)\n\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43.,), (44.,),), dtype=np.float64)\n    expected_train_result = b'my_train_op'\n    features = {'x': np.array(((42.,),), dtype=np.float32)}\n    regularization_losses = [1.5, 0.5]\n    expected_regularization_loss = 2.\n    # unregularized_loss = ((43-45)^2 + (44-41)^2) / batch_size\n    #                    = (4 + 9) / 2 = 6.5\n    # loss = unregularized_loss + regularization_loss = 8.5\n    expected_loss = 8.5\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    loss = head.loss(\n        labels,\n        logits,\n        features=features,\n        
mode=ModeKeys.TRAIN,\n        regularization_losses=regularization_losses)\n    preds = head.predictions(logits)\n    self.assertAllClose(logits, self.evaluate(preds[prediction_key]))\n    self.assertAllClose(expected_loss, self.evaluate(loss))\n    if tf.executing_eagerly():\n      return\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        regularization_losses=regularization_losses,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      predictions, loss, train_result, summary_str = sess.run(\n          (spec.predictions[prediction_key], spec.loss, spec.train_op,\n           spec.scaffold.summary_op))\n      self.assertAllClose(logits, predictions)\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      test_lib._assert_simple_summaries(\n          self, {\n              metric_keys.MetricKeys.LOSS:\n                  expected_loss,\n              metric_keys.MetricKeys.LOSS_REGULARIZATION:\n                  (expected_regularization_loss),\n          }, summary_str)\n\n  def test_weighted_multi_example_eval(self):\n    \"\"\"1d label, 3 examples, 1 batch.\"\"\"\n    head = head_lib.RegressionHead(weight_column='label_weights')\n    self.assertEqual(1, head.logits_dimension)\n    logits = 
np.array(((45,), (41,), (44,)), dtype=np.int32)\n    features = {\n        'x': np.array(((42,), (43,), (44,)), dtype=np.int32),\n        'label_weights': np.array(((1.,), (.1,), (1.5,)), dtype=np.float32),\n    }\n    labels = np.array(((35,), (42,), (45,)), dtype=np.int32)\n\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    preds = head.predictions(logits)\n    self.assertItemsEqual((prediction_key,), preds.keys())\n    predictions = preds[prediction_key]\n    self.assertEqual(tf.dtypes.float32, predictions.dtype)\n    self.assertAllClose(logits, self.evaluate(predictions))\n\n    # loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2 = 100+.1+1.5 = 101.6\n    # expected_loss = loss / batch_size = 33.8666667\n    expected_loss = 33.8666667\n    # loss_mean = loss/(1+.1+1.5) = 101.6/2.6 = 39.0769231\n    expected_loss_mean = 39.0769231\n    if tf.executing_eagerly():\n      eval_metrics = head.metrics()\n      updated_metrics = head.update_metrics(eval_metrics, features, logits,\n                                            labels)\n      self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN,\n                             metric_keys.MetricKeys.PREDICTION_MEAN,\n                             metric_keys.MetricKeys.LABEL_MEAN),\n                            updated_metrics.keys())\n      self.assertAllClose(\n          expected_loss_mean,\n          updated_metrics[metric_keys.MetricKeys.LOSS_MEAN].result())\n      loss = head.loss(labels, logits, features=features, mode=ModeKeys.EVAL)\n      self.assertIsNotNone(loss)\n      self.assertAllClose(expected_loss, loss)\n    else:\n      # Create estimator spec.\n      spec = head.create_estimator_spec(\n          features=features,\n          mode=ModeKeys.EVAL,\n          logits=logits,\n          labels=labels,\n          trainable_variables=[\n              tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n          ])\n      # Assert spec contains expected tensors.\n      
self.assertEqual(tf.dtypes.float32, spec.loss.dtype)\n      self.assertIsNone(spec.train_op)\n      self.assertIsNone(spec.export_outputs)\n      test_lib._assert_no_hooks(self, spec)\n      # Assert predictions, loss, and metrics.\n      with self.cached_session() as sess:\n        test_lib._initialize_variables(self, spec.scaffold)\n        self.assertIsNone(spec.scaffold.summary_op)\n        loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[\n            metric_keys.MetricKeys.LOSS_MEAN]\n        loss, _ = sess.run((spec.loss, loss_mean_update_op))\n        self.assertAllClose(expected_loss, loss)\n        # Check results of value ops (in `loss_mean`).\n        self.assertAllClose(expected_loss_mean, loss_mean_value_op.eval())\n\n  def test_weight_with_numeric_column(self):\n    \"\"\"1d label, 3 examples, 1 batch.\"\"\"\n    head = head_lib.RegressionHead(\n        weight_column=tf.feature_column.numeric_column(\n            'label_weights', normalizer_fn=lambda x: x + 1.))\n\n    logits = np.array(((45,), (41,), (44,)), dtype=np.int32)\n    features = {\n        'x': np.array(((42,), (43,), (44,)), dtype=np.int32),\n        'label_weights': np.array(((0.,), (-0.9,), (0.5,)), dtype=np.float32),\n    }\n    labels = np.array(((35,), (42,), (45,)), dtype=np.int32)\n\n    loss = head.loss(labels, logits, features=features)\n    # weighted_loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2\n    #               = 100+.1+1.5 = 101.6\n    # loss = weighted_loss / batch_size = 101.6 / 3 = 33.8666667\n    self.assertAllClose(33.8666667, self.evaluate(loss))\n\n  def test_weighted_multi_example_train(self):\n    \"\"\"1d label, 3 examples, 1 batch.\"\"\"\n    head = head_lib.RegressionHead(weight_column='label_weights')\n    self.assertEqual(1, head.logits_dimension)\n\n    features = {\n        'x': np.array(((42,), (43,), (44,)), dtype=np.float32),\n        'label_weights': np.array(((1.,), (.1,), (1.5,)), dtype=np.float64),\n    }\n    labels = 
np.array(((35.,), (42.,), (45.,)), dtype=np.float32)\n    logits = np.array(((45,), (41,), (44,)), dtype=np.float32)\n    expected_train_result = b'my_train_op'\n    # weighted_loss = 1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2\n    #               = 100+.1+1.5 = 101.6\n    # expected_loss = weighted_loss / batch_size = 101.6 / 3 = 33.8666667\n    expected_loss = 33.8666667\n    preds = head.predictions(logits)\n    loss = head.loss(labels, logits, features=features, mode=ModeKeys.TRAIN)\n\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    self.assertItemsEqual((prediction_key,), preds.keys())\n    self.assertEqual(tf.dtypes.float32, preds[prediction_key].dtype)\n    self.assertEqual(tf.dtypes.float32, loss.dtype)\n    self.assertAllClose(logits, self.evaluate(preds[prediction_key]))\n    self.assertAllClose(expected_loss, self.evaluate(loss))\n    if tf.executing_eagerly():\n      return\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    # Assert spec contains expected tensors.\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    test_lib._assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      predictions, loss, train_result, 
summary_str = sess.run(\n          (spec.predictions[prediction_key], spec.loss, spec.train_op,\n           spec.scaffold.summary_op))\n      self.assertAllClose(logits, predictions)\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      test_lib._assert_simple_summaries(self, {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n      }, summary_str)\n\n  def test_train_one_dim_create_loss(self):\n    \"\"\"Tests create_loss with 1D labels and weights (shape [batch_size]).\"\"\"\n    head = head_lib.RegressionHead(weight_column='label_weights')\n    logits = np.array(((45,), (41,), (44,)), dtype=np.float32)\n    x_feature_rank_1 = np.array((42., 43., 44.,), dtype=np.float32)\n    weight_rank_1 = np.array((1., .1, 1.5,), dtype=np.float64)\n    labels_rank_1 = np.array((35., 42., 45.,))\n    # training_loss = (100 * 1 + 1 * .1 + 1.5 * 1) / batch_size\n    #               = 101.6 / 3 = 33.8666667\n    expected_training_loss = 33.8666667\n    features = {'x': x_feature_rank_1, 'label_weights': weight_rank_1}\n    # Create loss.\n    training_loss = head.loss(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels_rank_1)\n    self.assertAllClose(expected_training_loss, self.evaluate(training_loss))\n\n  def test_train_one_dim(self):\n    \"\"\"Tests train with 1D labels and weights (shape [batch_size]).\"\"\"\n    head = head_lib.RegressionHead(weight_column='label_weights')\n    self.assertEqual(1, head.logits_dimension)\n\n    logits = np.array(((45,), (41,), (44,)), dtype=np.float32)\n    expected_train_result = b'my_train_op'\n    # loss = (1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2) / batch_size\n    #      = (100+.1+1.5) / 3 = 101.6 / 3 = 33.8666667\n    expected_loss = 33.8666667\n    x_feature_rank_1 = np.array((42., 43., 44.,), dtype=np.float32)\n    weight_rank_1 = np.array((1., .1, 1.5,), dtype=np.float64)\n    labels_rank_1 = np.array((35., 
42., 45.,))\n    features = {'x': x_feature_rank_1, 'label_weights': weight_rank_1}\n    self.assertEqual((3,), x_feature_rank_1.shape)\n    self.assertEqual((3,), weight_rank_1.shape)\n    self.assertEqual((3,), labels_rank_1.shape)\n    preds = head.predictions(logits)\n    loss = head.loss(\n        labels=labels_rank_1,\n        logits=logits,\n        features=features,\n        mode=ModeKeys.TRAIN)\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    self.assertItemsEqual((prediction_key,), preds.keys())\n    self.assertEqual(tf.dtypes.float32, preds[prediction_key].dtype)\n    self.assertEqual(tf.dtypes.float32, loss.dtype)\n    self.assertAllClose(logits, self.evaluate(preds[prediction_key]))\n    self.assertAllClose(expected_loss, self.evaluate(loss))\n    if tf.executing_eagerly():\n      return\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels_rank_1,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    # Assert spec contains expected tensors.\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    test_lib._assert_no_hooks(self, spec)\n\n    # Assert predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      predictions, loss, train_result, summary_str = sess.run(\n          (spec.predictions[prediction_key], spec.loss, 
spec.train_op,\n           spec.scaffold.summary_op))\n      self.assertAllClose(logits, predictions)\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      test_lib._assert_simple_summaries(self, {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n      }, summary_str)\n\n  def test_weighted_multi_value_eval_create_loss(self):\n    \"\"\"3d label, 1 example, 1 batch.\"\"\"\n    head = head_lib.RegressionHead(\n        weight_column='label_weights', label_dimension=3)\n    logits = np.array(((45., 41., 44.),))\n    labels = np.array(((35., 42., 45.),))\n    features = {\n        'x': np.array(((42., 43., 44.),)),\n        'label_weights': np.array(((1., .1, 1.5),))\n    }\n    regularized_training_loss = head.loss(\n        logits=logits, labels=labels, features=features)\n    # loss = [(35-45)^2, (42-41)^2, (45-44)^2] = [100, 1, 1].\n    # weighted_sum_loss = 1 * 100 + .1 * 1 + 1.5 * 1 = 101.6\n    # expected_training_loss = weighted_sum_loss / batch_size\n    #                        = 101.6 / 3 = 33.8666667\n    self.assertAllClose(33.8666667, self.evaluate(regularized_training_loss))\n\n  def test_weighted_multi_value_eval(self):\n    \"\"\"3d label, 1 example, 1 batch.\"\"\"\n    head = head_lib.RegressionHead(\n        weight_column='label_weights', label_dimension=3)\n    self.assertEqual(3, head.logits_dimension)\n\n    logits = np.array(((45., 41., 44.),))\n    labels = np.array(((35., 42., 45.),))\n    features = {\n        'x': np.array(((42., 43., 44.),)),\n        'label_weights': np.array(((1., .1, 1.5),))\n    }\n\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    preds = head.predictions(logits)\n    self.assertItemsEqual((prediction_key,), preds.keys())\n    predictions = preds[prediction_key]\n    self.assertEqual(tf.dtypes.float32, predictions.dtype)\n    self.assertAllClose(logits, self.evaluate(predictions))\n\n    # weighted_loss = 1*(35-45)^2 + .1*(42-41)^2 
+ 1.5*(45-44)^2\n    #               = 100+.1+1.5 = 101.6\n    # expected_loss = weighted_loss / batch_size = 101.6 / 3 = 33.8666667\n    expected_loss = 33.8666667\n    # loss_mean = weighted_loss/(1+.1+1.5) = 101.6/2.6 = 39.0769231\n    expected_loss_mean = 39.0769231\n    if tf.executing_eagerly():\n      eval_metrics = head.metrics()\n      updated_metrics = head.update_metrics(eval_metrics, features, logits,\n                                            labels)\n      self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN,\n                             metric_keys.MetricKeys.PREDICTION_MEAN,\n                             metric_keys.MetricKeys.LABEL_MEAN),\n                            updated_metrics.keys())\n      self.assertAllClose(\n          expected_loss_mean,\n          updated_metrics[metric_keys.MetricKeys.LOSS_MEAN].result())\n      loss = head.loss(labels, logits, features=features, mode=ModeKeys.EVAL)\n      self.assertIsNotNone(loss)\n      self.assertAllClose(expected_loss, loss)\n    else:\n      # Create estimator spec.\n      spec = head.create_estimator_spec(\n          features=features,\n          mode=ModeKeys.EVAL,\n          logits=logits,\n          labels=labels,\n          trainable_variables=[\n              tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n          ])\n      # Assert spec contains expected tensors.\n      self.assertEqual(tf.dtypes.float32, spec.loss.dtype)\n      self.assertItemsEqual((metric_keys.MetricKeys.LOSS_MEAN,\n                             metric_keys.MetricKeys.PREDICTION_MEAN,\n                             metric_keys.MetricKeys.LABEL_MEAN),\n                            spec.eval_metric_ops.keys())\n      self.assertIsNone(spec.train_op)\n      self.assertIsNone(spec.export_outputs)\n      test_lib._assert_no_hooks(self, spec)\n      # Assert predictions, loss, and metrics.\n      with self.cached_session() as sess:\n        test_lib._initialize_variables(self, spec.scaffold)\n        
self.assertIsNone(spec.scaffold.summary_op)\n        loss_mean_value_op, loss_mean_update_op = spec.eval_metric_ops[\n            metric_keys.MetricKeys.LOSS_MEAN]\n        loss, _ = sess.run((spec.loss, loss_mean_update_op))\n        self.assertAllClose(expected_loss, loss)\n        # Check results of value ops (in `loss_mean`).\n        self.assertAllClose(expected_loss_mean, loss_mean_value_op.eval())\n\n  def test_weighted_multi_value_train_create_loss(self):\n    \"\"\"3d label, 1 example, 1 batch.\"\"\"\n    head = head_lib.RegressionHead(\n        weight_column='label_weights', label_dimension=3)\n    logits = np.array(((45., 41., 44.),))\n    labels = np.array(((35., 42., 45.),))\n    features = {\n        'x': np.array(((42., 43., 44.),)),\n        'label_weights': np.array(((1., .1, 1.5),))\n    }\n\n    # Create loss.\n    regularized_training_loss = head.loss(\n        logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n    self.assertAllClose(33.8666667, self.evaluate(regularized_training_loss))\n\n  def test_weighted_multi_value_train(self):\n    \"\"\"3d label, 1 example, 1 batch.\"\"\"\n    head = head_lib.RegressionHead(\n        weight_column='label_weights', label_dimension=3)\n    self.assertEqual(3, head.logits_dimension)\n\n    logits = np.array(((45., 41., 44.),))\n    labels = np.array(((35., 42., 45.),))\n    expected_train_result = b'my_train_op'\n    # loss = (1*(35-45)^2 + .1*(42-41)^2 + 1.5*(45-44)^2) / batch_size\n    #      = (100+.1+1.5) / 3 = 101.6 / 3 = 33.8666667\n    expected_loss = 33.8666667\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    features = {\n        'x': np.array(((42., 43., 44.),)),\n        'label_weights': np.array(((1., .1, 1.5),)),\n    
}\n    preds = head.predictions(logits)\n    loss = head.loss(labels, logits, features=features, mode=ModeKeys.TRAIN)\n    prediction_key = prediction_keys.PredictionKeys.PREDICTIONS\n    self.assertItemsEqual((prediction_key,), preds.keys())\n    self.assertEqual(tf.dtypes.float32, preds[prediction_key].dtype)\n    self.assertEqual(tf.dtypes.float32, loss.dtype)\n    self.assertAllClose(logits, self.evaluate(preds[prediction_key]))\n    self.assertAllClose(expected_loss, self.evaluate(loss))\n    if tf.executing_eagerly():\n      return\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    # Assert spec contains expected tensors.\n    self.assertEqual({}, spec.eval_metric_ops)\n    self.assertIsNotNone(spec.train_op)\n    self.assertIsNone(spec.export_outputs)\n    test_lib._assert_no_hooks(self, spec)\n\n    # Evaluate predictions, loss, train_op, and summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      predictions, loss, train_result, summary_str = sess.run(\n          (spec.predictions[prediction_key], spec.loss, spec.train_op,\n           spec.scaffold.summary_op))\n      self.assertAllClose(logits, predictions)\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n      test_lib._assert_simple_summaries(self, {\n          metric_keys.MetricKeys.LOSS: expected_loss,\n      }, summary_str)\n\n  def test_weighted_multi_batch_eval_eager(self):\n    \"\"\"1d label, 1 example, 3 batches.\"\"\"\n    with tf.compat.v2.__internal__.eager_context.eager_mode():\n      head = head_lib.RegressionHead(weight_column='label_weights')\n      self.assertEqual(1, 
head.logits_dimension)\n\n      logits = np.array(((45.,), (41.,), (44.,)))\n      features = {\n          'x': np.array(((42.,), (43.,), (44.,))),\n          'label_weights': np.array(((1.,), (.1,), (1.5,))),\n          # 'logits' is not a feature, but we use `tf.data.Dataset` to make it\n          # as a `tensor` (required by `update_metrics`), and access it\n          # via `features['logits']` in `update_metrics`\n          'logits': logits\n      }\n      labels = np.array(((35.,), (42.,), (45.,)))\n\n      # losses = [1*(35-45)^2, .1*(42-41)^2, 1.5*(45-44)^2] = [100, .1, 1.5]\n      # loss = sum(losses) = 100+.1+1.5 = 101.6\n      # loss_mean = loss/(1+.1+1.5) = 101.6/2.6 = 39.076923\n      expected_metrics = {\n          metric_keys.MetricKeys.LOSS_MEAN: 39.076923,\n          metric_keys.MetricKeys.PREDICTION_MEAN:\n              (45 + 41 * 0.1 + 44 * 1.5) / 2.6,\n          metric_keys.MetricKeys.LABEL_MEAN: (35 + 42 * 0.1 + 45 * 1.5) / 2.6,\n      }\n      dataset = tf.compat.v1.data.Dataset.from_tensor_slices((features, labels))\n      dataset = dataset.batch(1)\n      eval_metrics = head.metrics()\n      for (features, labels) in dataset:\n        logits = features['logits']\n        updated_metrics = head.update_metrics(eval_metrics, features, logits,\n                                              labels)\n        # Assert metrics.\n      self.assertAllClose(\n          expected_metrics,\n          {k: updated_metrics[k].result() for k in updated_metrics})\n\n  def test_weighted_multi_batch_train_eager(self):\n    \"\"\"1d label, 1 example, 3 batches.\"\"\"\n    if tf.executing_eagerly():\n      head = head_lib.RegressionHead(weight_column='label_weights')\n      self.assertEqual(1, head.logits_dimension)\n\n      logits = np.array(((45.,), (41.,), (44.,)))\n      features = {\n          'x': np.array(((42.,), (43.,), (44.,))),\n          'label_weights': np.array(((1.,), (.1,), (1.5,))),\n          # 'logits' is not a feature, but we use 
`tf.data.Dataset` to make it\n          # as a `tensor` (required by `update_metrics`), and access it\n          # via `features['logits']` in `update_metrics`\n          'logits': logits\n      }\n      labels = np.array(((35.,), (42.,), (45.,)))\n      dataset = tf.compat.v1.data.Dataset.from_tensor_slices((features, labels))\n      dataset = dataset.batch(1)\n      expected_losses = np.array((100, .1, 1.5))\n      for (batch, (features, labels)) in enumerate(dataset):\n        logits = features['logits']\n        loss = head.loss(labels, logits, features=features)\n        self.assertAllClose(expected_losses[batch], loss)\n\n  def test_multi_dim_weighted_train_create_loss(self):\n    \"\"\"Logits, labels of shape [2, 2, 3], weight shape [2, 2].\"\"\"\n    label_dimension = 3\n    head = head_lib.RegressionHead(\n        weight_column='label_weights', label_dimension=label_dimension)\n    logits = np.array([[[00., 01., 02.], [10., 11., 12.]],\n                       [[20., 21., 22.], [30., 31., 32.]]])\n    labels = np.array([[[01., 02., 03.], [12., 13., 14.]],\n                       [[23., 24., 25.], [34., 35., 36.]]])\n    weights = np.array([[1., 1.5], [2., 2.5]])\n    training_loss_weighted_sum = np.sum(\n        np.array([[[1. * x for x in [1., 1., 1.]],\n                   [1.5 * x for x in [4., 4., 4.]]],\n                  [[2. 
* x for x in [9., 9., 9.]],\n                   [2.5 * x for x in [16., 16., 16.]]]]))\n    # batch_size = 2 * 2 * 3 = 12.\n    # expected_training_loss = training_loss_weighted_sum / batch_size\n    # Create loss.\n    training_loss = head.loss(\n        features={'label_weights': weights},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        regularization_losses=None)\n    self.assertAllClose(training_loss_weighted_sum / 12.,\n                        self.evaluate(training_loss))\n\n  def test_multi_dim_weighted_train(self):\n    \"\"\"Logits, labels of shape [2, 2, 3], weight shape [2, 2].\"\"\"\n    head = head_lib.RegressionHead(\n        weight_column='label_weights', label_dimension=3)\n    logits = np.array([[[00., 01., 02.], [10., 11., 12.]],\n                       [[20., 21., 22.], [30., 31., 32.]]])\n    labels = np.array([[[01., 02., 03.], [12., 13., 14.]],\n                       [[23., 24., 25.], [34., 35., 36.]]])\n    expected_train_result = b'my_train_op'\n    features = {\n        'label_weights': np.array([[1., 1.5], [2., 2.5]]),\n    }\n    # weighted_loss_sum = (1*3*1^2 + 1.5*3*2^2 + 2*3*3^2 +2.5*3*4^2) = 195\n    # loss = weighted_loss_sum / batch_size = 195 / (2*2*3) = 16.25\n    expected_loss = 16.25\n    loss = head.loss(labels, logits, features=features, mode=ModeKeys.TRAIN)\n    self.assertAllClose(expected_loss, self.evaluate(loss))\n    if tf.executing_eagerly():\n      return\n\n    # Create estimator spec.\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        
trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    with self.cached_session():\n      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())\n      self.assertAllClose(expected_loss, spec.loss.eval())\n\n  def test_multi_dim_train_weights_wrong_inner_dim(self):\n    \"\"\"Logits, labels of shape [2, 2, 3], weight shape [2, 1].\"\"\"\n    head = head_lib.RegressionHead(\n        weight_column='label_weights', label_dimension=3)\n    logits = np.array([[[00., 01., 02.], [10., 11., 12.]],\n                       [[20., 21., 22.], [30., 31., 32.]]])\n    labels = np.array([[[01., 02., 03.], [12., 13., 14.]],\n                       [[23., 24., 25.], [34., 35., 36.]]])\n    features = {\n        'label_weights': np.array([[1.], [2]]),\n    }\n\n    def _no_op_train_fn(loss):\n      del loss\n      return tf.no_op()\n\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, 'weights shape'):\n        head.loss(\n            features=features,\n            mode=ModeKeys.TRAIN,\n            logits=logits,\n            labels=labels,\n            regularization_losses=None)\n      return\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_no_op_train_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    with self.assertRaisesRegexp(\n        tf.errors.InvalidArgumentError,\n        r'\\[logits_shape: \\] \\[2 2 3\\] \\[weights_shape: \\] \\[2 1\\]'):\n      self.evaluate(spec.loss)\n\n  def test_multi_dim_train_weights_wrong_outer_dim(self):\n    \"\"\"Logits, labels of shape [2, 2, 3], weight shape [2, 2, 2].\"\"\"\n    head = head_lib.RegressionHead(\n        weight_column='label_weights', label_dimension=3)\n    logits = np.array([[[00., 01., 02.], [10., 11., 12.]],\n                       [[20., 21., 22.], [30., 31., 32.]]])\n    labels = np.array([[[01., 
02., 03.], [12., 13., 14.]],\n                       [[23., 24., 25.], [34., 35., 36.]]])\n\n    def _no_op_train_fn(loss):\n      del loss\n      return tf.no_op()\n\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError, 'weights shape'):\n        head.loss(\n            features={\n                'label_weights':\n                    np.array([[[1., 1.1], [1.5, 1.6]], [[2., 2.1], [2.5, 2.6]]])\n            },\n            mode=ModeKeys.TRAIN,\n            logits=logits,\n            labels=labels,\n            regularization_losses=None)\n      return\n\n    weights_placeholder = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n    features = {\n        'label_weights': weights_placeholder,\n    }\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_no_op_train_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    with self.cached_session():\n      test_lib._initialize_variables(self, tf.compat.v1.train.Scaffold())\n      with self.assertRaisesRegexp(\n          tf.errors.InvalidArgumentError,\n          r'\\[logits_shape: \\]\\s\\[2 2 3\\]\\s\\[weights_shape: \\]\\s\\[2 2 2\\]'):\n        spec.loss.eval({\n            weights_placeholder:\n                np.array([[[1., 1.1], [1.5, 1.6]], [[2., 2.1], [2.5, 2.6]]])\n        })\n\n\n@test_util.deprecated_graph_mode_only\nclass RegressionHeadForEstimator(tf.test.TestCase):\n  \"\"\"Tests for create_estimator_spec running in Graph mode only.\"\"\"\n\n  def test_invalid_trainable_variables(self):\n    head = head_lib.RegressionHead()\n\n    class _Optimizer(tf_keras.optimizers.Optimizer):\n\n      def get_updates(self, loss, params):\n        del params\n        return [\n            tf.strings.join([\n                tf.constant('my_train_op'),\n                tf.strings.as_string(loss, precision=2)\n            ])\n        ]\n\n      def 
get_config(self):\n        config = super(_Optimizer, self).get_config()\n        return config\n\n    with self.assertRaisesRegexp(ValueError,\n                                 r'trainable_variables cannot be None'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42.,),), dtype=np.float32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array(((45,), (41,),), dtype=np.float32),\n          labels=np.array(((43.,), (44.,),), dtype=np.float64),\n          optimizer=_Optimizer('my_optimizer'),\n          trainable_variables=None)\n    with self.assertRaisesRegexp(\n        ValueError, r'trainable_variables should be a list or a tuple'):\n      head.create_estimator_spec(\n          features={'x': np.array(((42.,),), dtype=np.float32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array(((45,), (41,),), dtype=np.float32),\n          labels=np.array(((43.,), (44.,),), dtype=np.float64),\n          optimizer=_Optimizer('my_optimizer'),\n          trainable_variables={\n              'var_list': [tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)]\n          })\n\n  def test_train_with_optimizer(self):\n    head = head_lib.RegressionHead()\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    labels = np.array(((43.,), (44.,),), dtype=np.float64)\n    expected_train_result = b'my_train_op'\n    features = {'x': np.array(((42.,),), dtype=np.float32)}\n    # loss = ((43-45)^2 + (44-41)^2) / 2 = (4 + 9) / 2 = 13 / 2 = 6.5\n    expected_loss = 6.5\n\n    class _Optimizer(tf_keras.optimizers.Optimizer):\n\n      def get_updates(self, loss, params):\n        del params\n        with tf.control_dependencies((tf.compat.v1.debugging.assert_equal(\n            tf.cast(expected_loss, dtype=tf.dtypes.float32),\n            tf.cast(loss, dtype=tf.dtypes.float32),\n            name='assert_loss'),)):\n          return [tf.constant(expected_train_result)]\n\n    
  def get_config(self):\n        config = super(_Optimizer, self).get_config()\n        return config\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        optimizer=_Optimizer('my_optimizer'),\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      loss, train_result = sess.run((spec.loss, spec.train_op))\n      self.assertAllClose(expected_loss, loss)\n      self.assertEqual(expected_train_result, train_result)\n\n  def test_train_with_update_ops(self):\n    with tf.Graph().as_default():\n      w = tf.Variable(1)\n      update_op = w.assign_add(1)\n\n      t = tf.Variable('')\n      expected_train_result = b'my_train_op'\n\n      def _train_op_fn(loss):\n        del loss\n        return t.assign(expected_train_result)\n\n      head = head_lib.RegressionHead()\n      spec = head.create_estimator_spec(\n          features={'x': np.array(((42,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=np.array(((45,), (41,),), dtype=np.float32),\n          labels=np.array(((43.,), (44.,),), dtype=np.float64),\n          update_ops=[update_op],\n          train_op_fn=_train_op_fn,\n          trainable_variables=[\n              tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n          ])\n\n      with self.cached_session() as sess:\n        test_lib._initialize_variables(self, spec.scaffold)\n        sess.run(spec.train_op)\n        w_value, t_value = sess.run([w, t])\n        self.assertEqual(2, w_value)\n        self.assertEqual(expected_train_result, t_value)\n\n  def test_train_summaries_with_head_name(self):\n    head = head_lib.RegressionHead(name='some_regression_head')\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = np.array(((45,), (41,),), dtype=np.float32)\n    
labels = np.array(((43.,), (44.,),), dtype=np.float64)\n    features = {'x': np.array(((42.,),), dtype=np.float32)}\n    # loss = ((43-45)^2 + (44-41)^2) / 2 = (4 + 9) / 2 = 6.5\n    expected_loss = 6.5\n\n    def _train_op_fn(loss):\n      del loss\n      return tf.no_op()\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    # Assert summaries.\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      summary_str = sess.run(spec.scaffold.summary_op)\n      test_lib._assert_simple_summaries(\n          self, {\n              '{}/some_regression_head'.format(metric_keys.MetricKeys.LOSS):\n                  expected_loss,\n          }, summary_str)\n\n  def test_weighted_multi_batch_train(self):\n    \"\"\"1d label, 1 example, 3 batches.\"\"\"\n    # numpy_input_fn is not compitable with eager.\n    head = head_lib.RegressionHead(weight_column='label_weights')\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = np.array(((45.,), (41.,), (44.,)))\n    input_fn = numpy_io.numpy_input_fn(\n        x={\n            'x': np.array(((42.,), (43.,), (44.,))),\n            'label_weights': np.array(((1.,), (.1,), (1.5,))),\n            # 'logits' is not a feature, but we use `numpy_input_fn` to make a\n            # batched version of it, and pop it off before passing to\n            # `create_estimator_spec`.\n            'logits': logits,\n        },\n        y=np.array(((35.,), (42.,), (45.,))),\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    batched_features, batched_labels = input_fn()\n    batched_logits = batched_features.pop('logits')\n    spec = head.create_estimator_spec(\n        
features=batched_features,\n        mode=ModeKeys.TRAIN,\n        logits=batched_logits,\n        labels=batched_labels,\n        train_op_fn=lambda loss: loss * -7.,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    # Assert spec contains expected tensors.\n    self.assertEqual(tf.dtypes.float32, spec.loss.dtype)\n    self.assertIsNotNone(spec.train_op)\n\n    with self.cached_session() as sess:\n      # Finalize graph and initialize variables.\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      tf.compat.v1.train.queue_runner.start_queue_runners()\n\n      results = tuple(\n          [sess.run((spec.loss, spec.train_op)) for _ in range(len(logits))])\n\n      # losses = [1*(35-45)^2, .1*(42-41)^2, 1.5*(45-44)^2] = [100, .1, 1.5]\n      expected_losses = np.array((100, .1, 1.5))\n      self.assertAllClose(expected_losses, [r[0] for r in results])\n      self.assertAllClose(expected_losses * -7., [r[1] for r in results])\n\n  def test_weighted_multi_batch_eval(self):\n    \"\"\"1d label, 1 example, 3 batches.\"\"\"\n    # numpy_input_fn is not compitable with eager.\n    head = head_lib.RegressionHead(weight_column='label_weights')\n    self.assertEqual(1, head.logits_dimension)\n\n    # Create estimator spec.\n    logits = np.array(((45.,), (41.,), (44.,)))\n    input_fn = numpy_io.numpy_input_fn(\n        x={\n            'x': np.array(((42.,), (43.,), (44.,))),\n            'label_weights': np.array(((1.,), (.1,), (1.5,))),\n            # 'logits' is not a feature, but we use `numpy_input_fn` to make a\n            # batched version of it, and pop it off before passing to\n            # `create_estimator_spec`.\n            'logits': logits,\n        },\n        y=np.array(((35.,), (42.,), (45.,))),\n        batch_size=1,\n        num_epochs=1,\n        shuffle=False)\n    batched_features, batched_labels = input_fn()\n    batched_logits = 
batched_features.pop('logits')\n    spec = head.create_estimator_spec(\n        features=batched_features,\n        mode=ModeKeys.EVAL,\n        logits=batched_logits,\n        labels=batched_labels,\n        train_op_fn=None,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    # losses = [1*(35-45)^2, .1*(42-41)^2, 1.5*(45-44)^2] = [100, .1, 1.5]\n    # loss = sum(losses) = 100+.1+1.5 = 101.6\n    # loss_mean = loss/(1+.1+1.5) = 101.6/2.6 = 39.076923\n    expected_metrics = {\n        metric_keys.MetricKeys.LOSS_MEAN:\n            39.076923,\n        metric_keys.MetricKeys.PREDICTION_MEAN:\n            (45 + 41 * 0.1 + 44 * 1.5) / 2.6,\n        metric_keys.MetricKeys.LABEL_MEAN: (35 + 42 * 0.1 + 45 * 1.5) / 2.6,\n    }\n\n    # Assert spec contains expected tensors.\n    self.assertEqual(tf.dtypes.float32, spec.loss.dtype)\n    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())\n    self.assertIsNone(spec.train_op)\n    test_lib._assert_no_hooks(self, spec)\n\n    with self.cached_session() as sess:\n      # Finalize graph and initialize variables.\n      test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNotNone(spec.scaffold.summary_op)\n      tf.compat.v1.train.queue_runner.start_queue_runners()\n\n      # Run tensors for `steps` steps.\n      steps = len(logits)\n      results = tuple([\n          sess.run((\n              spec.loss,\n              # The `[1]` gives us the metric update op.\n              {k: spec.eval_metric_ops[k][1]\n               for k in spec.eval_metric_ops}))\n          for _ in range(steps)\n      ])\n\n      # Assert losses and metrics.\n      self.assertAllClose((100, .1, 1.5), [r[0] for r in results])\n      # For metrics, check results of value ops (in `results`).\n      self.assertAllClose(\n          expected_metrics,\n          {k: spec.eval_metric_ops[k][0].eval() for k in spec.eval_metric_ops})\n\n\nclass 
PoissonRegressionHead(tf.test.TestCase):\n\n  def test_train(self):\n    head = head_lib.PoissonRegressionHead()\n\n    logits = np.array([[0], [-1], [1]], dtype=np.float32)\n    labels = np.array([[1], [2], [3]], dtype=np.int32)\n    features = {'x': np.array(((42.,),), dtype=np.int32)}\n    # With x = exp(logits), z = labels.\n    # loss = -ln(exp(-x) * (x^z) / z!)\n    #      = x - z * ln(x) + ln(z!)\n    #      = exp(logits) - labels * logits - ln(labels!)\n    # But for ln(z!) and z > 1, the Stirling approximation is used\n    # ln(z!) = z*ln(z) - z + 0.5*ln(2*pi*z)\n    # loss = [exp(0) - 1 * 0 + ln(1!),\n    #         exp(-1) - 2 * (-1) + 2*ln(2) - 2 + 0.5*ln(2*pi*2),\n    #         exp(1) - 3 * 1 + 3*ln(3) - 3 + 0.5*ln(2*pi*3)]\n    #      = [1.0, 3.020, 1.482]\n    # training_loss = (1.0 + 3.020 + 1.482) / 3\n    expected_loss = 1.834\n    atol = 0.001\n    if tf.executing_eagerly():\n      loss = head.loss(\n          logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n      self.assertAlmostEqual(expected_loss, loss, delta=atol)\n      return\n\n    expected_train_result = b'my_train_op'\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_near(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          atol=atol,\n          name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      loss, train_result = sess.run([spec.loss, spec.train_op])\n      self.assertAlmostEqual(expected_loss, loss, delta=atol)\n      
self.assertEqual(expected_train_result, train_result)\n\n  def test_predict(self):\n    head = head_lib.PoissonRegressionHead()\n\n    logits = np.array([[0], [-1], [1]], dtype=np.float32)\n    expected_predictions = np.exp(logits)\n    keys = prediction_keys.PredictionKeys\n\n    preds = head.predictions(logits)\n    self.assertItemsEqual((keys.PREDICTIONS, keys.LOGITS), preds.keys())\n    self.assertEqual(tf.dtypes.float32, preds[keys.PREDICTIONS].dtype)\n    self.assertEqual(tf.dtypes.float32, preds[keys.LOGITS].dtype)\n    self.assertAllClose(expected_predictions,\n                        self.evaluate(preds[keys.PREDICTIONS]))\n    self.assertAllClose(logits, self.evaluate(preds[keys.LOGITS]))\n\n\nclass LogisticRegressionHead(tf.test.TestCase):\n\n  def test_train(self):\n    head = head_lib.LogisticRegressionHead()\n\n    logits = np.array([[0], [-1], [1]], dtype=np.float32)\n    labels = np.array([[.4], [.6], [.8]], dtype=np.float32)\n    features = {'x': np.array(((42.,),), dtype=np.int32)}\n    # Following the documentation in\n    # tf.nn.sigmoid_cross_entropy_with_logits:\n    # With x = logits, z = labels.\n    # loss  = max(x, 0) - x * z + log(1 + exp(-abs(x)))\n    # loss = [0 - 0 * 0.4 + ln(1 + exp(-0)),\n    #         0 + 1 * 0.6 + ln(1 + exp(-1)),\n    #         1 - 1 * 0.8 + ln(1 + exp(-1))]\n    #      = [0.6931, 0.9133, 0.5133]\n    # training_loss = (0.6931 + 0.9133 + 0.5133) / 3\n    expected_loss = 0.7066\n    atol = 0.001\n    if tf.executing_eagerly():\n      loss = head.loss(\n          logits=logits, labels=labels, features=features, mode=ModeKeys.TRAIN)\n      self.assertAlmostEqual(expected_loss, loss, delta=atol)\n      return\n\n    expected_train_result = b'my_train_op'\n\n    def _train_op_fn(loss):\n      with tf.control_dependencies((tf.compat.v1.debugging.assert_near(\n          tf.cast(expected_loss, dtype=tf.dtypes.float32),\n          tf.cast(loss, dtype=tf.dtypes.float32),\n          atol=atol,\n          
name='assert_loss'),)):\n        return tf.constant(expected_train_result)\n\n    # Create estimator spec.\n    spec = head.create_estimator_spec(\n        features={'x': np.array(((42.,),), dtype=np.int32)},\n        mode=ModeKeys.TRAIN,\n        logits=logits,\n        labels=labels,\n        train_op_fn=_train_op_fn,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    with self.cached_session() as sess:\n      test_lib._initialize_variables(self, spec.scaffold)\n      loss, train_result = sess.run([spec.loss, spec.train_op])\n      self.assertAlmostEqual(expected_loss, loss, delta=atol)\n      self.assertEqual(expected_train_result, train_result)\n\n  def test_train_labels_too_large(self):\n    head = head_lib.LogisticRegressionHead()\n\n    logits = np.array([[0], [-1], [1]], dtype=np.float32)\n    labels = np.array([[.4], [1.2], [.8]], dtype=np.float32)\n    features = {'x': np.array(((42.,),), dtype=np.int32)}\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError,\n                                   r'Labels must be in range \\[0, 1\\]'):\n        head.loss(\n            logits=logits,\n            labels=labels,\n            features=features,\n            mode=ModeKeys.TRAIN)\n      return\n\n    expected_train_result = b'my_train_op'\n\n    def _train_op_fn(loss):\n      del loss\n      return tf.constant(expected_train_result)\n\n    # Create estimator spec.\n    with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,\n                                 r'Labels must be in range \\[0, 1\\]'):\n      spec = head.create_estimator_spec(\n          features=features,\n          mode=ModeKeys.TRAIN,\n          logits=logits,\n          labels=labels,\n          train_op_fn=_train_op_fn,\n          trainable_variables=[\n              tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n          ])\n\n  def test_train_labels_negative(self):\n    head = head_lib.LogisticRegressionHead()\n\n    logits 
= np.array([[0], [-1], [1]], dtype=np.float32)\n    labels = np.array([[.4], [-0.2], [.8]], dtype=np.float32)\n    features = {'x': np.array(((42.,),), dtype=np.int32)}\n\n    if tf.executing_eagerly():\n      with self.assertRaisesRegexp(ValueError,\n                                   r'Labels must be in range \\[0, 1\\]'):\n        head.loss(\n            logits=logits,\n            labels=labels,\n            features=features,\n            mode=ModeKeys.TRAIN)\n      return\n\n    expected_train_result = b'my_train_op'\n\n    def _train_op_fn(loss):\n      del loss\n      return tf.constant(expected_train_result)\n\n    # Create estimator spec.\n    with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,\n                                 r'Labels must be in range \\[0, 1\\]'):\n      spec = head.create_estimator_spec(\n          features={'x': np.array(((42.,),), dtype=np.int32)},\n          mode=ModeKeys.TRAIN,\n          logits=logits,\n          labels=labels,\n          train_op_fn=_train_op_fn,\n          trainable_variables=[\n              tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n          ])\n\n  def test_predict(self):\n    head = head_lib.LogisticRegressionHead()\n\n    logits = np.array([[0], [-1], [1]], dtype=np.float32)\n    expected_predictions = 1. / (1. + np.exp(-logits))\n    keys = prediction_keys.PredictionKeys\n\n    preds = head.predictions(logits)\n    self.assertItemsEqual((keys.PREDICTIONS, keys.LOGITS), preds.keys())\n    self.assertEqual(tf.dtypes.float32, preds[keys.PREDICTIONS].dtype)\n    self.assertEqual(tf.dtypes.float32, preds[keys.LOGITS].dtype)\n    self.assertAllClose(expected_predictions,\n                        self.evaluate(preds[keys.PREDICTIONS]))\n    self.assertAllClose(logits, self.evaluate(preds[keys.LOGITS]))\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/sequential_head.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Defines a head for sequential models.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport abc\n\nimport six\nimport tensorflow as tf\n\nif six.PY3:\n  from collections.abc import Iterable\nelse:\n  from collections import Iterable\n\nfrom tensorflow.python.framework import ops\nfrom tensorflow_estimator.python.estimator.head import base_head\nfrom tensorflow_estimator.python.estimator.head import multi_head\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n\nclass _SequentialHead(base_head.Head):\n  \"\"\"Interface for the head of a sequential model.\n\n  A sequential head handles input sequences of different lengths to compute the\n  output of a model. It requires a sequence mask tensor, to indicate which steps\n  of the sequences are padded and ensure proper aggregation for loss and metrics\n  computation. 
It has a `input_sequence_mask_key` property that specifies which\n  tensor of the feature dictionary to use as the sequence mask tensor.\n\n  Such a head can for instance be used with `RNNEstimator` for sequential\n  predictions.\n\n  Example of usage:\n    ```python\n    def _my_model_fn(features, labels, mode, params, config=None):\n      feature_layer = tf.feature_column.SequenceFeatureLayer(columns)\n      input_layer, sequence_length = feature_layer(features)\n      sequence_length_mask = tf.sequence_mask(sequence_length)\n      rnn_layer = tf_keras.layers.RNN(cell=tf_keras.layers.SimpleRNNCell(units),\n                                      return_sequences=True)\n      logits = rnn_layer(input_layer, mask=sequence_length_mask)\n      features[sequential_head.input_sequence_mask_key] = sequence_length_mask\n      return sequential_head.create_estimator_spec(\n          features=features,\n          labels=labels,\n          mode=mode,\n          logits=logits,\n          optimizer=optimizer)\n    ```\n  \"\"\"\n  __metaclass__ = abc.ABCMeta\n\n  @abc.abstractproperty\n  def input_sequence_mask_key(self):\n    \"\"\"Key of the sequence mask tensor in the feature dictionary.\n\n    Returns:\n      A string.\n    \"\"\"\n    raise NotImplementedError('Calling an abstract method.')\n\n\nclass SequentialHeadWrapper(_SequentialHead):\n  \"\"\"Sequential head wrapping a Head object.\n\n  Wraps a `Head` object and applies a sequential mask to:\n    - Loss aggregation: To only account for masked steps. 
Used for\n      `create_estimator_spec` and `loss` methods.\n    - Metrics: The sequence mask is used to only account for mask steps in\n      metrics computation with the `update_metrics` method.\n    - Predictions: To add a sequence length mask tensor to the predictions\n      dictionary.\n  \"\"\"\n\n  def __init__(self,\n               static_head,\n               sequence_length_mask='sequence_length_mask',\n               feature_columns=None):\n    \"\"\"Initializes a `SequentialHeadWrapper` instance.\n\n    Example of usage:\n      ```python\n      # Define a sequential head.\n      static_head = tf.estimator.BinaryClassHead(weight_column='weights')\n      sequential_head = head_lib.SequentialHeadWrapper(\n          static_head=static_head, sequence_length_mask='mask',\n          feature_columns='weights')\n\n      # Define feature columns and parsing spec.\n      feature_columns = [\n        tf.feature_column.sequence_numeric_column('sequential-feature')\n      ]\n      label_column = tf.feature_column.sequence_numeric_column(\n          'label', dtype=tf.int32),\n      weight_column = tf.feature_column.sequence_numeric_column('weights')\n      parsing_spec = tf.feature_column.make_parse_example_spec(\n          feature_columns + [label_column, weight_column])\n\n      # Use the head in a model function.\n      def _my_model_fn(features, labels, mode, params, config=None):\n        feature_layer = tf.feature_column.SequenceFeatureLayer(feature_columns)\n        input_layer, sequence_length = feature_layer(features)\n        sequence_length_mask = tf.sequence_mask(sequence_length)\n        rnn_layer = tf_keras.layers.RNN(\n            cell=tf_keras.layers.SimpleRNNCell(units),\n            return_sequences=True)\n        logits = rnn_layer(input_layer, mask=sequence_length_mask)\n        features['mask'] = sequence_length_mask\n        return sequential_head.create_estimator_spec(\n            features=features,\n            labels=labels,\n            
mode=mode,\n            logits=logits,\n            optimizer=optimizer)\n      ```\n\n    Args:\n      static_head: `Head` object, static head to wrap.\n      sequence_length_mask: `str`, name of sequence length mask tensor in\n        features dictionary. Tensor must be a dense tensor of shape [batch_size,\n        seq_length].\n      feature_columns: `str` or list of the former. Specifies the features of\n        the features dictionary to which the sequence length mask must be\n        applied, and which are passed to the static head's methods when calling\n        `create_estimator_spec`, `loss` or `update_metrics`. This is typically a\n        weight tensor.\n\n    Raises:\n      TypeError: If `sequence_length_mask` is not of string type.\n      TypeError: If provided features columns are not of string type.\n    \"\"\"\n    # Verify and set sequence mask column.\n    # TODO(aarg): Add support for `NumericColumn`.\n    if not isinstance(sequence_length_mask, six.string_types):\n      raise TypeError('`sequence_mask` column must be a string. '\n                      'Given type: {}.'.format(type(sequence_length_mask)))\n    self._sequence_length_mask = sequence_length_mask\n\n    # Verify and set feature columns (to be flattened).\n    feature_columns = feature_columns or []\n    if not isinstance(feature_columns, Iterable):\n      raise TypeError('`feature_columns` must be either a string or an '\n                      'iterable of strings got {} instead.'.format(\n                          type(feature_columns)))\n    if isinstance(feature_columns, six.string_types):\n      self._feature_columns = [feature_columns]\n    else:\n      self._feature_columns = feature_columns\n\n    for column in self._feature_columns:\n      # TODO(aarg): Add support for `NumericColumn` and `SequenceNumericColumn`.\n      if not isinstance(column, six.string_types):\n        raise TypeError('Column must a string. 
Given type: {}.'.format(\n            type(column)))\n\n    # Set other variables.\n    if isinstance(static_head, multi_head.MultiHead):\n      # TODO(aarg): Add support for MultiHead.\n      raise ValueError(\n          '`MultiHead` is not supported with `SequentialHeadWrapper`.')\n    self._static_head = static_head\n\n    super(SequentialHeadWrapper, self).__init__()\n\n  def _flatten(self, labels, logits, features):\n    \"\"\"Flattens labels, logits, and features tensors.\n\n    Provided tensors need to have at least two dimensions. The two first\n    dimensions of the provided tensors are flattened to one single dimension.\n    If a tensor is dense, the sequence mask in the features dictionary is used\n    to flatten it.\n\n    Note: If indices of a sparse tensor are not sorted, they will be reordered.\n\n    Args:\n      labels: `Tensor` or `SparseTensor` to flatten.\n      logits: `Tensor` or `SparseTensor` to flatten.\n      features: Dictionary of `Tensor` or `SparseTensor` objects to flatten.\n\n    Returns:\n      - Dense `Tensor` with flattened labels.\n      - Dense `Tensor` with flattened logits.\n      - Dictionary of flattened dense `Tensor` objects.\n\n    Raises:\n      ValueError: If the sequence mask is not found in `features`.\n      ValueError: If one of the provided tensors to flatten has not at least two\n        dimensions.\n    \"\"\"\n    # Retrieve sequence_mask from features dictionary.\n    if self.input_sequence_mask_key not in features:\n      raise ValueError('The provided sequence_length_mask key `{}` should be '\n                       'included in the features dictionary, but was not '\n                       'found. 
Found keys: {}.'.format(\n                           self.input_sequence_mask_key, list(features.keys())))\n    sequence_mask = features[self.input_sequence_mask_key]\n    if sequence_mask.get_shape().ndims != 2:\n      raise ValueError('Mask is expected to have two dimensions, got '\n                       '{} instead.'.format(sequence_mask.get_shape().ndims))\n\n    with ops.name_scope('flatten'):\n      expected_length = tf.math.reduce_sum(\n          tf.cast(sequence_mask, tf.dtypes.int32))\n      # Flatten logits and labels.\n      flat_logits = _flatten_tensor(logits, sequence_mask, expected_length)\n      flat_labels = _flatten_tensor(labels, sequence_mask, expected_length)\n\n      # Flatten features.\n      flat_features = {}\n      for column in self._feature_columns:\n        if column not in features:\n          raise ValueError('`{}` column expected in features '\n                           'dictionary.'.format(column))\n        flat_features[column] = _flatten_tensor(features[column], sequence_mask,\n                                                expected_length)\n\n      return flat_labels, flat_logits, flat_features\n\n  def loss(self,\n           logits,\n           labels,\n           features=None,\n           mode=None,\n           regularization_losses=None):\n    \"\"\"Flattens input and returns regularized training loss.\n\n    Flattens `logits`, `labels`, and `features` tensors that are specified by\n    the head's `feature_columns` before calling the static head's `loss` method.\n\n    Args:\n      logits: Logits `Tensor` of rank >= 2 and shape [batch_size, seq_length,\n        D2, ... DN].\n      labels: Labels `Tensor` or `SparseTensor` or rank >= 2 and shape\n        [batch_size, seq_length, D2, ... DN].\n      features: Input `dict` mapping string feature names to `Tensor` or\n        `SparseTensor` objects containing the values for that feature in a\n        minibatch. Must contain the sequence length mask tensor. 
Features\n        corresponding to the sequential's head `feature_columns` are flattened\n        and passed to the static head's `loss` method.\n      mode: Estimator's `ModeKeys`. To be used in case loss calculation is\n        different in Train and Eval mode.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses.\n\n    Returns:\n      A scalar `Tensor` representing regularized training loss used in train and\n      eval.\n    \"\"\"\n    flat_labels, flat_logits, flat_features = self._flatten(\n        labels, logits, features)\n    return self._static_head.loss(\n        logits=flat_logits,\n        labels=flat_labels,\n        features=flat_features,\n        mode=mode,\n        regularization_losses=regularization_losses)\n\n  def create_estimator_spec(self,\n                            features,\n                            mode,\n                            logits,\n                            labels=None,\n                            optimizer=None,\n                            trainable_variables=None,\n                            train_op_fn=None,\n                            update_ops=None,\n                            regularization_losses=None):\n    \"\"\"Returns `EstimatorSpec` that a model_fn can return.\n\n    If in TRAIN or EVAL mode, `logits`, `labels`, and `features` tensors\n    corresponding to the head's `feature_columns` are flattened before calling\n    the static head's `create_estimator_spec` method.\n    If in PREDICT mode, no flattening is done. The `EstimatatorSpec` is computed\n    using the static head's `create_estimator_spec` method. The sequence length\n    mask tensor is added to the predictions dictionary.\n\n    Args:\n      features: Input `dict` mapping string feature names to `Tensor` or\n        `SparseTensor` objects containing the values for that feature in a\n        minibatch. 
If in TRAIN or EVAL mode, only specified features are\n        flattened and passed to the static head's method.\n      mode: Estimator's `ModeKeys`.\n      logits: Logits `Tensor` of rank >= 2 and shape [batch_size, seq_length,\n        D2, ... DN].\n      labels: Labels `Tensor` or `SparseTensor` or rank >= 2 and shape\n        [batch_size, seq_length, D2, ... DN].\n      optimizer: An `tf_keras.optimizers.Optimizer` instance to optimize the\n        loss in TRAIN mode. Namely, sets\n        `train_op = optimizer.get_updates(loss, trainable_variables)`, which\n        updates variables to minimize `loss`.\n      trainable_variables: A list or tuple of `Variable` objects to update to\n        minimize `loss`. In Tensorflow 1.x, by default these are the list of\n        variables collected in the graph under the key\n        `GraphKeys.TRAINABLE_VARIABLES`. As Tensorflow 2.x doesn't have\n        collections and GraphKeys, trainable_variables need to be passed\n        explicitly here.\n      train_op_fn: Function that takes a scalar loss `Tensor` and returns an op\n        to optimize the model with the loss in TRAIN mode. Used if `optimizer`\n        is `None`. Exactly one of `train_op_fn` and `optimizer` must be set in\n        TRAIN mode. By default, it is `None` in other modes. If you want to\n        optimize loss yourself, you can pass `lambda _: tf.no_op()` and then use\n          `EstimatorSpec.loss` to compute and apply gradients.\n      update_ops: A list or tuple of update ops to be run at training time. For\n        example, layers such as BatchNormalization create mean and variance\n        update ops that need to be run at training time. In Tensorflow 1.x,\n        these are thrown into an UPDATE_OPS collection. 
As Tensorflow 2.x\n        doesn't have collections, update_ops need to be passed explicitly here.\n      regularization_losses: A list of additional scalar losses to be added to\n        the training loss, such as regularization losses.\n\n    Returns:\n      `EstimatorSpec`.\n    \"\"\"\n    if mode == ModeKeys.PREDICT:\n      spec = self._static_head.create_estimator_spec(\n          features=features, mode=mode, logits=logits)\n      spec.predictions[self.input_sequence_mask_key] = features[\n          self.input_sequence_mask_key]\n      return spec._replace(predictions=spec.predictions)\n\n    flat_labels, flat_logits, flat_features = self._flatten(\n        labels, logits, features)\n\n    return self._static_head.create_estimator_spec(\n        features=flat_features,\n        mode=mode,\n        logits=flat_logits,\n        trainable_variables=trainable_variables,\n        labels=flat_labels,\n        optimizer=optimizer,\n        train_op_fn=train_op_fn,\n        regularization_losses=regularization_losses,\n        update_ops=update_ops)\n\n  def update_metrics(self,\n                     eval_metrics,\n                     features,\n                     logits,\n                     labels,\n                     regularization_losses=None):\n    \"\"\"Updates metric objects and returns a `dict` of the updated metrics.\n\n    Flattens `logits`, `labels`, and `features` tensors that are specified by\n    the head's feature_columns` before calling the static head's\n    `update_metrics` method.\n\n    Args:\n      eval_metrics: A `dict` of metrics to be updated.\n      features: Input `dict` mapping string feature names to `Tensor` or\n        `SparseTensor` objects containing the values for that feature in a\n        minibatch. Only specified features are flattened and passed to the\n        static head's method.\n      logits: Logits `Tensor` of rank >= 2 and shape [batch_size, seq_length,\n        D2, ... 
DN].\n      labels: Labels `Tensor` or `SparseTensor` or rank >= 2 and shape\n        [batch_size, seq_length, D2, ... DN].\n      regularization_losses: A list of additional scalar losses to be added to\n        the training and evaluation loss, such as regularization losses.\n\n    Returns:\n       A `dict` of updated metrics keyed by name. The value is an instance of\n       `Metric` class.\n    \"\"\"\n    flat_labels, flat_logits, flat_features = self._flatten(\n        labels, logits, features)\n    return self._static_head.update_metrics(\n        eval_metrics=eval_metrics,\n        features=flat_features,\n        logits=flat_logits,\n        labels=flat_labels,\n        regularization_losses=regularization_losses)\n\n  def _create_tpu_estimator_spec(self,\n                                 features,\n                                 mode,\n                                 logits,\n                                 labels=None,\n                                 optimizer=None,\n                                 trainable_variables=None,\n                                 train_op_fn=None,\n                                 update_ops=None,\n                                 regularization_losses=None):\n    raise NotImplementedError\n\n  def predictions(self, logits, keys=None):\n    \"\"\"Calls the static head's `predictions` method.\"\"\"\n    return self._static_head.predictions(logits, keys=keys)\n\n  def metrics(self, regularization_losses=None):\n    \"\"\"Calls the static head's `metrics` method.\"\"\"\n    return self._static_head.metrics(regularization_losses)\n\n  @property\n  def input_sequence_mask_key(self):\n    \"\"\"Returns the key for the sequence mask feature.\"\"\"\n    return self._sequence_length_mask\n\n  @property\n  def logits_dimension(self):\n    \"\"\"Returns the logits dimension of the static head.\"\"\"\n    return self._static_head.logits_dimension\n\n  @property\n  def loss_reduction(self):\n    \"\"\"Returns the loss reduction of 
the static head.\"\"\"\n    return self._static_head.loss_reduction\n\n  @property\n  def name(self):\n    \"\"\"Returns the name of the static head.\"\"\"\n    if self._static_head.name:\n      return '{}_sequential'.format(self._static_head.name)\n    return None\n\n  @property\n  def static_head(self):\n    \"\"\"Returns the wrapped static head.\"\"\"\n    return self._static_head\n\n\ndef _flatten_tensor(tensor, sequence_mask, expected_length):\n  \"\"\"Flattens the two first dimensions and reshapes a tensor or sparse tensor.\n\n  If `tensor` is a dense tensor, the sequence_mask is used to infer valid\n  inputs.\n\n  Note: If `tensor` is a `SparseTensor` and the indices are not sorted, they\n  will be reordered.\n\n  Args:\n    tensor: A `Tensor` or `SparseTensor` of dimension at least 2, of shape\n      [batch_size, seq_length, D0, D1, ..., DN].\n    sequence_mask: A boolean `Tensor` of shape [batch_size, seq_length].\n    expected_length: A integer scalar `Tensor` with the expected length of the\n      resulting flattenned Tensor.\n\n  Returns:\n    A `Tensor` object of shape [expected_length, D0, D1, ..., DN].\n\n  Raises:\n    ValueError: If `tensor` has not at least 2 dimensions.\n    ValueError: If `tensor` is not a `Tensor` or `SparseTensor` object.\n    InvalidArgumentError: If the resulting `Tensor` doesn't have the expected\n      length.\n  \"\"\"\n  shape = tensor.get_shape()\n  if shape.ndims < 2:\n    raise ValueError('Input tensor expected to have at least 2 dimensions, '\n                     'got {} instead.'.format(shape.ndims))\n  if isinstance(tensor, tf.sparse.SparseTensor):\n    # What follows depends on the indices ordering. 
Hence we reorder the indices\n    # to ensure correctness.\n    flat_tensor = tf.sparse.reorder(tensor).values\n    if shape.ndims > 2:\n      new_shape = tf.concat([[-1], shape[2:]], axis=0)\n      flat_tensor = tf.reshape(flat_tensor, new_shape)\n  elif isinstance(tensor, tf.Tensor):\n    flat_tensor = tf.boolean_mask(tensor, sequence_mask)\n  else:\n    raise ValueError('`tensor` expected to be a `Tensor` or  `SparseTensor` '\n                     'got `{}` instead.'.format(tensor))\n  if shape.ndims == 2:\n    flat_tensor = tf.compat.v1.expand_dims(flat_tensor, -1)\n    expected_shape = tf.concat([[expected_length], [1]], axis=0)\n  else:\n    expected_shape = tf.concat([[expected_length], shape[2:]], axis=0)\n\n  # TODO(b/119617064): Unify eager and graph implementations.\n  err_message = 'Tensor shape is incompatible with provided mask.'\n  if tf.executing_eagerly():\n    if flat_tensor._shape_tuple() != tuple(expected_shape.numpy()):  # pylint: disable=protected-access\n      raise ValueError(err_message)\n    return flat_tensor\n  with tf.control_dependencies([\n      tf.compat.v1.debugging.assert_equal(\n          tf.compat.v1.shape(flat_tensor), expected_shape, message=err_message)\n  ]):\n    return tf.identity(flat_tensor)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/head/sequential_head_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for sequential_head.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\n\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.canned import metric_keys\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.head import binary_class_head as binary_head_lib\nfrom tensorflow_estimator.python.estimator.head import head_utils as test_lib\nfrom tensorflow_estimator.python.estimator.head import multi_class_head as multi_head_lib\nfrom tensorflow_estimator.python.estimator.head import multi_head\nfrom tensorflow_estimator.python.estimator.head import sequential_head as seq_head_lib\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n\ndef _convert_to_tensor(features):\n  \"\"\"Converts an arrays or dict of arrays to tensors or dict of tensors.\"\"\"\n  if isinstance(features, dict):\n    if set(features.keys()) == set(['indices', 'values', 
'dense_shape']):\n      return tf.sparse.SparseTensor(**features)\n    for col in features:\n      features[col] = _convert_to_tensor(features[col])\n    return features\n  return ops.convert_to_tensor(features)\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass TestFlatten(tf.test.TestCase, parameterized.TestCase):\n  \"\"\"Tests flatten functions.\"\"\"\n\n  @parameterized.named_parameters(\n      {\n          'testcase_name': 'one_dim_sparse_tensor',\n          'tensor': {\n              'indices': ((0, 0), (0, 1), (1, 0)),\n              'values': (1, 2, 3),\n              'dense_shape': (2, 2)\n          },\n          'expected': [[1], [2], [3]]\n      }, {\n          'testcase_name': 'multi_dim_sparse_tensor',\n          'tensor': {\n              'indices': ((0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 0, 0),\n                          (1, 0, 1)),\n              'values': (1, 2, 3, 4, 5, 6),\n              'dense_shape': (2, 2, 2)\n          },\n          'expected': [[1, 2], [3, 4], [5, 6]]\n      }, {\n          'testcase_name': 'one_dim_dense_tensor',\n          'tensor': [[1, 2], [3, 4]],\n          'expected': [[1], [2], [3]]\n      }, {\n          'testcase_name': 'multi_dim_dense_tensor',\n          'tensor': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],\n          'expected': [[1, 2], [3, 4], [5, 6]]\n      }, {\n          'testcase_name': 'unsorted_sparse_indices',\n          'tensor': {\n              'indices': ((0, 0), (1, 0), (0, 1)),\n              'values': (1, 3, 2),\n              'dense_shape': (2, 2)\n          },\n          'expected': [[1], [2], [3]]\n      })\n  def test_flatten_tensor(self, tensor, expected):\n    \"\"\"Tests the output of the `_flatten_tensor` function.\n\n    Args:\n      tensor: Dense or sparse array.\n      expected: Array with expected output of `_flatten_tensor`.\n    \"\"\"\n    sequence_mask = np.array([[1, 1], [1, 0]])\n    tensor = _convert_to_tensor(tensor)\n    flat_tensor = seq_head_lib._flatten_tensor(\n  
      tensor, sequence_mask, expected_length=sequence_mask.sum())\n    if tf.executing_eagerly():\n      self.assertAllEqual(flat_tensor, expected)\n      return\n    with self.cached_session() as sess:\n      self.assertAllEqual(sess.run(flat_tensor), expected)\n\n  def _test_flatten_method(self, features, feature_columns):\n    \"\"\"Runs seq head's `_flatten` method and returns output for testing.\"\"\"\n    head = seq_head_lib.SequentialHeadWrapper(\n        static_head=None,\n        sequence_length_mask='sequence_mask',\n        feature_columns=feature_columns)\n    labels = {\n        'indices': ((0, 0), (0, 1), (1, 0)),\n        'values': (1, 2, 3),\n        'dense_shape': (2, 2)\n    }\n    logits = np.array([[[10], [11]], [[12], [13]]])\n\n    features = _convert_to_tensor(features)\n    labels = tf.sparse.SparseTensor(**labels)\n    logits = ops.convert_to_tensor(logits)\n    output = head._flatten(labels, logits, features)\n    if tf.executing_eagerly():\n      return output\n    with self.cached_session() as sess:\n      return sess.run(output)\n\n  def test_flatten_method(self):\n    \"\"\"Tests output of `_flatten` method.\"\"\"\n    features = {'sequence_mask': np.array([[1, 1], [1, 0]])}\n    expected_output = ([[1], [2], [3]], [[10], [11], [12]], {})\n    output = self._test_flatten_method(features, feature_columns=[])\n    self.assertAllClose(expected_output, output)\n\n  def test_flatten_with_one_feature_columns(self):\n    \"\"\"Tests output of `_flatten` method with one feature column provided.\"\"\"\n    features = {\n        'sequence_mask': np.array([[1, 1], [1, 0]]),\n        'weights': np.array([[0.5, 0.5], [1., 0]])\n    }\n    expected_output = ([[1], [2], [3]], [[10], [11], [12]], {\n        'weights': np.array([[0.5], [0.5], [1.]])\n    })\n    output = self._test_flatten_method(features, feature_columns='weights')\n    self.assertAllClose(expected_output, output)\n\n  def test_flatten_with_multiple_feature_columns(self):\n    
\"\"\"Tests `_flatten` method with multiple feature columns provided.\"\"\"\n    features = {\n        'sequence_mask': np.array([[1, 1], [1, 0]]),\n        'a': np.array([[0.5, 0.5], [1., 0]]),\n        'b': np.array([[1.5, 1.5], [2., 0]])\n    }\n    expected_output = ([[1], [2], [3]], [[10], [11], [12]], {\n        'a': np.array([[0.5], [0.5], [1.]]),\n        'b': np.array([[1.5], [1.5], [2.]])\n    })\n    output = self._test_flatten_method(features, feature_columns=['a', 'b'])\n    self.assertAllClose(expected_output, output)\n\n  def test_flatten_no_mask(self):\n    \"\"\"Tests error in `_flatten` method when sequence mask is not provided.\"\"\"\n    features = {}\n    with self.assertRaisesRegexp(\n        ValueError, (r'The provided sequence_length_mask key `sequence_mask` '\n                     r'should be included in.* Found keys: \\[\\].')):\n      _ = self._test_flatten_method(features, feature_columns=[])\n\n  def test_flatten_missing_feature(self):\n    \"\"\"Tests error in `_flatten` method when feature is not provided.\"\"\"\n    features = {'sequence_mask': np.array([[1, 1], [1, 0]])}\n    with self.assertRaisesRegexp(\n        ValueError, '`weights` column expected in features dictionary.'):\n      _ = self._test_flatten_method(features, feature_columns=['weights'])\n\n  def test_flatten_tensor_wrong_feature_dim(self):\n    \"\"\"Tests `_flatten` method when feature has wrong dimension.\"\"\"\n    features = {\n        'sequence_mask': np.array([[1, 1], [1, 0]]),\n        'weights': np.array([0.5, 0.5, 1., 0])\n    }\n    with self.assertRaisesRegexp(\n        ValueError, 'Input tensor expected to have at least 2 dimensions.'):\n      _ = self._test_flatten_method(features, feature_columns=['weights'])\n\n  def test_flatten_tensor_wrong_feature_mask(self):\n    \"\"\"Tests `_flatten` with feature mask different from provided mask.\"\"\"\n    features = {'sequence_mask': np.array([[1, 1], [1, 1]])}\n    error = (\n        ValueError\n        if 
tf.executing_eagerly() else tf.errors.InvalidArgumentError)\n    with self.assertRaisesRegexp(\n        error, 'Tensor shape is incompatible with provided mask.'):\n      _ = self._test_flatten_method(features, feature_columns=[])\n\n  def test_flatten_tensor_wrong_mask_dim(self):\n    \"\"\"Tests `_flatten` with mask that has wrong dimensions.\"\"\"\n    features = {'sequence_mask': np.array([1, 1])}\n    with self.assertRaisesRegexp(\n        ValueError, 'Mask is expected to have two dimensions, got .* instead.'):\n      _ = self._test_flatten_method(features, feature_columns=[])\n\n\nclass _MockHead(object):\n  \"\"\"A static head to be wrapped in a sequential head, for testing.\"\"\"\n\n  def metrics(self, regularization_losses=None):\n    return regularization_losses\n\n  def loss(self, **kwargs):\n    return kwargs\n\n  def create_estimator_spec(self, **kwargs):\n    Spec = collections.namedtuple('Spec', ['predictions', 'kwargs'])  # pylint: disable=invalid-name\n    return Spec(predictions={}, kwargs=kwargs)\n\n\n@test_util.run_all_in_graph_and_eager_modes\nclass TestSequentialHead(tf.test.TestCase):\n  \"\"\"Tests sequential head methods.\"\"\"\n\n  def _assert_equal(self, d, dref, session=None):\n    \"\"\"Recursively checks that all items of a dictionary are close.\n\n    Dictionary can contain numerical values, `Tensor` objects or dictionaries of\n    the former.\n\n    If an item is a `Tensor`, its value is evaluated then compared to the\n    reference.\n\n    Args:\n      d: Dictionary to check.\n      dref: Dictionary to use as a reference for checks.\n      session: A `tf.Session` object.\n    \"\"\"\n    for key, ref_item in dref.items():\n      if isinstance(ref_item, dict):\n        self._assert_equal(d[key], dref=ref_item, session=session)\n      elif isinstance(d[key], tf.Tensor):\n        self.assertAllClose(\n            session.run(d[key]) if session else d[key], ref_item)\n      else:\n        self.assertEqual(d[key], ref_item)\n\n  def 
test_predictions(self):\n    \"\"\"Tests predictions output.\n\n    Use `predictions` method in eager execution, else `create_estimator_spec` in\n    PREDICT mode.\n\n    logits = [[0.3, -0.4], [0.2, 0.2]]\n    logistics = 1 / (1 + exp(-logits))\n              = [[0.57, 0.40], [0.55, 0.55]]\n    \"\"\"\n    head = seq_head_lib.SequentialHeadWrapper(binary_head_lib.BinaryClassHead(),\n                                              'sequence_mask')\n\n    logits = [[[0.3], [-0.4]], [[0.2], [0.2]]]\n    expected_logistics = [[[0.574443], [0.401312]], [[0.549834], [0.549834]]]\n\n    features = {\n        'sequence_mask': ops.convert_to_tensor(np.array([[1, 1], [1, 0]]))\n    }\n\n    keys = prediction_keys.PredictionKeys\n    if tf.executing_eagerly():\n      predictions = head.predictions(\n          logits=logits, keys=[keys.LOGITS, keys.LOGISTIC])\n      self.assertItemsEqual(predictions.keys(), [keys.LOGITS, keys.LOGISTIC])\n      self.assertAllClose(logits, predictions[keys.LOGITS])\n      self.assertAllClose(expected_logistics, predictions[keys.LOGISTIC])\n      return\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.PREDICT,\n        logits=logits,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n    self.assertIn('sequence_mask', spec.predictions)\n    with self.cached_session() as sess:\n      self.assertAllEqual(\n          sess.run(spec.predictions['sequence_mask']),\n          features['sequence_mask'])\n      self.assertAllClose(logits, sess.run(spec.predictions[keys.LOGITS]))\n      self.assertAllClose(expected_logistics,\n                          sess.run(spec.predictions[keys.LOGISTIC]))\n\n  def test_metrics(self):\n    \"\"\"Tests the `metrics` method.\n\n    Tests that:\n    - Returned metrics match the returned metrics of the static head.\n    - `regularization_losses` argument is properly passed to the static head's\n      method.\n    \"\"\"\n    head = 
seq_head_lib.SequentialHeadWrapper(binary_head_lib.BinaryClassHead(),\n                                              'mask')\n    metrics = head.metrics(regularization_losses=2.5)\n    keys = metric_keys.MetricKeys\n    self.assertIn(keys.ACCURACY, metrics)\n    self.assertIn(keys.LOSS_REGULARIZATION, metrics)\n\n  def test_loss_args(self):\n    \"\"\"Tests that variables are flattened and passed to static head's method.\"\"\"\n    logits = [[1, 2], [3, 4]]\n    labels = [[0, 1], [0, 2]]\n    features = {'weights': [[0.3, 0.2], [0.5, 100]], 'mask': [[1, 1], [1, 0]]}\n    head = seq_head_lib.SequentialHeadWrapper(_MockHead(), 'mask', 'weights')\n    expected_output = {\n        'logits': [[1], [2], [3]],\n        'labels': [[0], [1], [0]],\n        'features': {\n            'weights': [[0.3], [0.2], [0.5]]\n        },\n        'mode': 'my-mode',\n        'regularization_losses': 123\n    }\n    output = head.loss(\n        logits=_convert_to_tensor(logits),\n        labels=_convert_to_tensor(labels),\n        features=_convert_to_tensor(features),\n        mode='my-mode',\n        regularization_losses=123)\n    with self.cached_session() as sess:\n      self._assert_equal(output, dref=expected_output, session=sess)\n\n  def test_create_estimator_spec_args(self):\n    \"\"\"Tests that variables are flattened and passed to static head's method.\"\"\"\n    logits = [[1, 2], [3, 4]]\n    labels = [[0, 1], [0, 2]]\n    features = {'weights': [[0.3, 0.2], [0.5, 100]], 'mask': [[1, 1], [1, 0]]}\n    head = seq_head_lib.SequentialHeadWrapper(_MockHead(), 'mask', 'weights')\n    w = tf.Variable(1)\n    update_op = w.assign_add(1)\n    trainable_variables = [tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)]\n    expected_output = {\n        'logits': [[1], [2], [3]],\n        'labels': [[0], [1], [0]],\n        'features': {\n            'weights': [[0.3], [0.2], [0.5]]\n        },\n        'mode': ModeKeys.TRAIN,\n        'regularization_losses': 123,\n        'optimizer': 
'my-opt',\n        'train_op_fn': 'my-train-op',\n        'trainable_variables': trainable_variables,\n        'update_ops': [update_op]\n    }\n    spec = head.create_estimator_spec(\n        logits=_convert_to_tensor(logits),\n        labels=_convert_to_tensor(labels),\n        features=_convert_to_tensor(features),\n        mode=ModeKeys.TRAIN,\n        optimizer='my-opt',\n        train_op_fn='my-train-op',\n        regularization_losses=123,\n        update_ops=[update_op],\n        trainable_variables=trainable_variables)\n    with self.cached_session() as sess:\n      self.assertItemsEqual(spec.kwargs.keys(), expected_output.keys())\n      self._assert_equal(spec.kwargs, dref=expected_output, session=sess)\n\n  def test_head_properties(self):\n    \"\"\"Tests that the head's properties are correcly implemented.\"\"\"\n    static_head = binary_head_lib.BinaryClassHead(\n        loss_reduction=tf.losses.Reduction.SUM, name='a_static_head')\n    head = seq_head_lib.SequentialHeadWrapper(static_head,\n                                              'a_sequence_mask_col')\n    self.assertEqual(head.name, 'a_static_head_sequential')\n    self.assertEqual(head.logits_dimension, 1)\n    self.assertEqual(head.loss_reduction, tf.losses.Reduction.SUM)\n    self.assertEqual(head.input_sequence_mask_key, 'a_sequence_mask_col')\n    self.assertEqual(head.static_head.name, 'a_static_head')\n\n  def test_loss_reduction(self):\n    \"\"\"Tests loss reduction.\n\n    Use `loss` method in eager execution, else `create_estimator_spec` in TRAIN\n    mode.\n\n    logits = [[[2., 3., 4.], [5., -0.5, 0.]],\n              [[-1.0, 2.0, 0.5], [_]]],\n    labels = [[0, 1],\n              [2, _]]\n    weights = [[0.5, 0.2],\n               [0.3, _]]\n    loss = [0.5*2.40 + 0.2*5.41 + 0.3*1.74] / 3 = 0.94\n    \"\"\"\n    static_head = multi_head_lib.MultiClassHead(\n        n_classes=3, weight_column='weights')\n    head = seq_head_lib.SequentialHeadWrapper(static_head, 'sequence_mask',\n 
                                             'weights')\n    expected_loss = 0.942783\n    features = {\n        'weights':\n            tf.sparse.SparseTensor(\n                indices=((0, 0), (0, 1), (1, 0)),\n                values=(0.5, 0.2, 0.3),\n                dense_shape=(2, 2)),\n        'sequence_mask':\n            ops.convert_to_tensor([[1, 1], [1, 0]])\n    }\n    logits = ops.convert_to_tensor([[[2., 3., 4.], [5., -0.5, 0.]],\n                                    [[-1.0, 2.0, 0.5], [1.0, 0.5, 2.0]]])\n    labels = tf.sparse.SparseTensor(\n        indices=((0, 0), (0, 1), (1, 0)), values=(0, 1, 2), dense_shape=(2, 2))\n\n    class _Optimizer(tf_keras.optimizers.Optimizer):\n\n      def get_updates(self, loss, params):\n        del params, loss\n        return [tf.constant('op')]\n\n      def get_config(self):\n        config = super(_Optimizer, self).get_config()\n        return config\n\n    if tf.executing_eagerly():\n      loss = head.loss(logits=logits, labels=labels, features=features)\n    else:\n      spec = head.create_estimator_spec(\n          features,\n          ModeKeys.TRAIN,\n          logits,\n          labels=labels,\n          optimizer=_Optimizer('my_optimizer'),\n          trainable_variables=[\n              tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)\n          ])\n      with self.cached_session() as sess:\n        loss = sess.run(spec.loss)\n    self.assertAllClose(loss, expected_loss, atol=1e-4)\n\n  def test_metrics_computation(self):\n    \"\"\"Runs metrics computation tests.\n\n    Use `update_metrics` method in eager execution, else `create_estimator_spec`\n    in EVAL mode.\n\n    logits = [[-101, 102, -103], [104, _, _]]\n    predicted_labels = [[0, 1, 0], [1, _, _]]\n    labels = [[1, 1, 1], [1, _, _]]\n    weights = [[2, 5, 1], [2, _, _]]\n\n    loss = (101*2 + 103*1) / 10 = 30.5\n    accuracy = (0 + 5 + 0 + 2) / (2 + 5 + 1 + 2) = 0.7\n    prediction_mean = (0 + 5 + 0 + 2) / (2 + 5 + 1 + 2) = 0.7\n    precision = 
(5 + 2) / (5 + 2) = 1.0\n    recall = (5 + 2) / (2 + 5 + 1 + 2) = 0.7\n    \"\"\"\n    static_head = binary_head_lib.BinaryClassHead(weight_column='weights')\n    head = seq_head_lib.SequentialHeadWrapper(static_head, 'sequence_mask',\n                                              'weights')\n\n    features = {\n        'sequence_mask': np.array([[1, 1, 1], [1, 0, 0]]),\n        'weights': np.array([[2, 5, 1], [2, 100, 100]])\n    }\n    regularization_losses = [100.]\n    logits = _convert_to_tensor([[-101, 102, -103], [104, 100, 100]])\n    labels = tf.sparse.SparseTensor(\n        values=[1, 1, 1, 1],\n        indices=((0, 0), (0, 1), (0, 2), (1, 0)),\n        dense_shape=(2, 3))\n    features = _convert_to_tensor(features)\n    expected_loss = 30.5\n    keys = metric_keys.MetricKeys\n    expected_metrics = {\n        keys.LOSS_MEAN: expected_loss,\n        keys.ACCURACY: 0.7,\n        keys.PREDICTION_MEAN: 0.7,\n        keys.LABEL_MEAN: 1.0,\n        keys.LOSS_REGULARIZATION: 100,\n        keys.PRECISION: 1.0,\n        keys.RECALL: 0.7,\n        keys.ACCURACY_BASELINE: 1.0,\n        keys.AUC: 0.,\n        keys.AUC_PR: 1.0\n    }\n\n    if tf.executing_eagerly():\n      eval_metrics = head.metrics(regularization_losses=regularization_losses)\n      updated_metrics = head.update_metrics(eval_metrics, features, logits,\n                                            labels, regularization_losses)\n      self.assertItemsEqual(expected_metrics.keys(), updated_metrics.keys())\n      self.assertAllClose(\n          expected_metrics,\n          {k: updated_metrics[k].result() for k in updated_metrics})\n      return\n\n    spec = head.create_estimator_spec(\n        features=features,\n        mode=ModeKeys.EVAL,\n        logits=logits,\n        labels=labels,\n        regularization_losses=regularization_losses,\n        trainable_variables=[tf.Variable([1.0, 2.0], dtype=tf.dtypes.float32)])\n\n    with self.cached_session() as sess:\n      
test_lib._initialize_variables(self, spec.scaffold)\n      self.assertIsNone(spec.scaffold.summary_op)\n      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}\n      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}\n      _ = sess.run(update_ops)\n      self.assertAllClose(expected_metrics,\n                          {k: value_ops[k].eval() for k in value_ops})\n\n  def test_wrong_mask_type(self):\n    \"\"\"Tests error raised when the mask doesn't have proper type.\"\"\"\n    with self.assertRaisesRegexp(TypeError,\n                                 '`sequence_mask` column must be a string.'):\n      _ = seq_head_lib.SequentialHeadWrapper(None, sequence_length_mask=1)\n\n  def test_wrong_feature_column_type(self):\n    \"\"\"Tests error raised when the feature column doesn't have proper type.\"\"\"\n    with self.assertRaisesRegexp(\n        TypeError, '`feature_columns` must be either a string or an iterable'):\n      _ = seq_head_lib.SequentialHeadWrapper(None, 'mask', feature_columns=1)\n\n  def test_wrong_feature_column_type_in_iterable(self):\n    \"\"\"Tests error raised when the feature column doesn't have proper type.\"\"\"\n    with self.assertRaisesRegexp(TypeError,\n                                 'Column must a string. Given type: .*.'):\n      _ = seq_head_lib.SequentialHeadWrapper(None, 'mask', feature_columns=[1])\n\n  def test_multi_head_provided(self):\n    \"\"\"Tests error raised when a multi-head is provided.\"\"\"\n    with self.assertRaisesRegexp(\n        ValueError,\n        '`MultiHead` is not supported with `SequentialHeadWrapper`.'):\n      _ = seq_head_lib.SequentialHeadWrapper(\n          multi_head.MultiHead(\n              [binary_head_lib.BinaryClassHead(name='test-head')]))\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/hooks/__init__.py",
    "content": ""
  },
  {
    "path": "tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Some common SessionRunHook classes.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom tensorflow.python.training.basic_session_run_hooks import CheckpointSaverHook\nfrom tensorflow.python.training.basic_session_run_hooks import CheckpointSaverListener\nfrom tensorflow.python.training.basic_session_run_hooks import FeedFnHook\nfrom tensorflow.python.training.basic_session_run_hooks import FinalOpsHook\nfrom tensorflow.python.training.basic_session_run_hooks import GlobalStepWaiterHook\nfrom tensorflow.python.training.basic_session_run_hooks import LoggingTensorHook\nfrom tensorflow.python.training.basic_session_run_hooks import NanLossDuringTrainingError\nfrom tensorflow.python.training.basic_session_run_hooks import NanTensorHook\nfrom tensorflow.python.training.basic_session_run_hooks import ProfilerHook\nfrom tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer\nfrom tensorflow.python.training.basic_session_run_hooks import StepCounterHook\nfrom tensorflow.python.training.basic_session_run_hooks import StopAtStepHook\nfrom tensorflow.python.training.basic_session_run_hooks import SummarySaverHook\nfrom 
tensorflow_estimator.python.estimator.estimator_export import estimator_export\n\nestimator_export(\"estimator.SecondOrStepTimer\")(SecondOrStepTimer)\nestimator_export(\"estimator.LoggingTensorHook\")(LoggingTensorHook)\nestimator_export(\"estimator.StopAtStepHook\")(StopAtStepHook)\nestimator_export(\"estimator.CheckpointSaverListener\")(CheckpointSaverListener)\nestimator_export(\"estimator.CheckpointSaverHook\")(CheckpointSaverHook)\nestimator_export(\"estimator.StepCounterHook\")(StepCounterHook)\nestimator_export(\"estimator.NanLossDuringTrainingError\")(\n    NanLossDuringTrainingError)\nestimator_export(\"estimator.NanTensorHook\")(NanTensorHook)\nestimator_export(\"estimator.SummarySaverHook\")(SummarySaverHook)\nestimator_export(\"estimator.GlobalStepWaiterHook\")(GlobalStepWaiterHook)\nestimator_export(\"estimator.FinalOpsHook\")(FinalOpsHook)\nestimator_export(\"estimator.FeedFnHook\")(FeedFnHook)\nestimator_export(\"estimator.ProfilerHook\")(ProfilerHook)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/hooks/basic_session_run_hooks_test.py",
    "content": "# pylint: disable=g-bad-file-header\n# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for basic_session_run_hooks.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os.path\nimport shutil\nimport tempfile\nimport time\nimport tensorflow as tf\nfrom tensorflow.python.framework import meta_graph\nfrom tensorflow.python.framework import test_util\nfrom tensorflow.python.platform import tf_logging\nfrom tensorflow.python.training import monitored_session\nfrom tensorflow.python.training import training_util\nfrom tensorflow_estimator.python.estimator.hooks import basic_session_run_hooks\nfrom tensorflow_estimator.python.estimator.hooks import fake_summary_writer\n\n# Provide a realistic start time for unit tests where we need to mock out\n# calls to time.time().\nMOCK_START_TIME = 1484695987.209386\n\n\nclass MockCheckpointSaverListener(\n    basic_session_run_hooks.CheckpointSaverListener):\n\n  def __init__(self):\n    self.begin_count = 0\n    self.before_save_count = 0\n    self.after_save_count = 0\n    self.end_count = 0\n    self.ask_for_stop = False\n\n  def begin(self):\n    self.begin_count += 1\n\n  def before_save(self, session, global_step):\n    self.before_save_count += 1\n\n  def after_save(self, 
session, global_step):\n    self.after_save_count += 1\n    if self.ask_for_stop:\n      return True\n\n  def end(self, session, global_step):\n    self.end_count += 1\n\n  def get_counts(self):\n    return {\n        'begin': self.begin_count,\n        'before_save': self.before_save_count,\n        'after_save': self.after_save_count,\n        'end': self.end_count\n    }\n\n\n@test_util.deprecated_graph_mode_only\nclass SecondOrStepTimerTest(tf.test.TestCase):\n\n  def test_raise_in_both_secs_and_steps(self):\n    with self.assertRaises(ValueError):\n      basic_session_run_hooks.SecondOrStepTimer(every_secs=2.0, every_steps=10)\n\n  def test_raise_in_none_secs_and_steps(self):\n    with self.assertRaises(ValueError):\n      basic_session_run_hooks.SecondOrStepTimer()\n\n  @tf.compat.v1.test.mock.patch.object(time, 'time')\n  def test_every_secs(self, mock_time):\n    mock_time.return_value = MOCK_START_TIME\n    timer = basic_session_run_hooks.SecondOrStepTimer(every_secs=1.0)\n    self.assertTrue(timer.should_trigger_for_step(1))\n\n    timer.update_last_triggered_step(1)\n    self.assertFalse(timer.should_trigger_for_step(1))\n    self.assertFalse(timer.should_trigger_for_step(2))\n\n    mock_time.return_value += 1.0\n    self.assertFalse(timer.should_trigger_for_step(1))\n    self.assertTrue(timer.should_trigger_for_step(2))\n\n  def test_every_steps(self):\n    timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=3)\n    self.assertTrue(timer.should_trigger_for_step(1))\n\n    timer.update_last_triggered_step(1)\n    self.assertFalse(timer.should_trigger_for_step(1))\n    self.assertFalse(timer.should_trigger_for_step(2))\n    self.assertFalse(timer.should_trigger_for_step(3))\n    self.assertTrue(timer.should_trigger_for_step(4))\n\n  def test_update_last_triggered_step(self):\n    timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=1)\n\n    elapsed_secs, elapsed_steps = timer.update_last_triggered_step(1)\n    self.assertEqual(None, 
elapsed_secs)\n    self.assertEqual(None, elapsed_steps)\n\n    elapsed_secs, elapsed_steps = timer.update_last_triggered_step(5)\n    self.assertLess(0, elapsed_secs)\n    self.assertEqual(4, elapsed_steps)\n\n    elapsed_secs, elapsed_steps = timer.update_last_triggered_step(7)\n    self.assertLess(0, elapsed_secs)\n    self.assertEqual(2, elapsed_steps)\n\n\n@test_util.deprecated_graph_mode_only\nclass StopAtStepTest(tf.test.TestCase):\n\n  def test_raise_in_both_last_step_and_num_steps(self):\n    with self.assertRaises(ValueError):\n      basic_session_run_hooks.StopAtStepHook(num_steps=10, last_step=20)\n\n  def test_stop_based_on_last_step(self):\n    h = basic_session_run_hooks.StopAtStepHook(last_step=10)\n    with tf.Graph().as_default():\n      global_step = tf.compat.v1.train.get_or_create_global_step()\n      no_op = tf.no_op()\n      h.begin()\n      with tf.compat.v1.Session() as sess:\n        mon_sess = monitored_session._HookedSession(sess, [h])\n        sess.run(tf.compat.v1.assign(global_step, 5))\n        h.after_create_session(sess, None)\n        mon_sess.run(no_op)\n        self.assertFalse(mon_sess.should_stop())\n        sess.run(tf.compat.v1.assign(global_step, 9))\n        mon_sess.run(no_op)\n        self.assertFalse(mon_sess.should_stop())\n        sess.run(tf.compat.v1.assign(global_step, 10))\n        mon_sess.run(no_op)\n        self.assertTrue(mon_sess.should_stop())\n        sess.run(tf.compat.v1.assign(global_step, 11))\n        mon_sess._should_stop = False\n        mon_sess.run(no_op)\n        self.assertTrue(mon_sess.should_stop())\n\n  def test_stop_based_on_num_step(self):\n    h = basic_session_run_hooks.StopAtStepHook(num_steps=10)\n\n    with tf.Graph().as_default():\n      global_step = tf.compat.v1.train.get_or_create_global_step()\n      no_op = tf.no_op()\n      h.begin()\n      with tf.compat.v1.Session() as sess:\n        mon_sess = monitored_session._HookedSession(sess, [h])\n        
sess.run(tf.compat.v1.assign(global_step, 5))\n        h.after_create_session(sess, None)\n        mon_sess.run(no_op)\n        self.assertFalse(mon_sess.should_stop())\n        sess.run(tf.compat.v1.assign(global_step, 13))\n        mon_sess.run(no_op)\n        self.assertFalse(mon_sess.should_stop())\n        sess.run(tf.compat.v1.assign(global_step, 14))\n        mon_sess.run(no_op)\n        self.assertFalse(mon_sess.should_stop())\n        sess.run(tf.compat.v1.assign(global_step, 15))\n        mon_sess.run(no_op)\n        self.assertTrue(mon_sess.should_stop())\n        sess.run(tf.compat.v1.assign(global_step, 16))\n        mon_sess._should_stop = False\n        mon_sess.run(no_op)\n        self.assertTrue(mon_sess.should_stop())\n\n  def test_stop_based_with_multiple_steps(self):\n    h = basic_session_run_hooks.StopAtStepHook(num_steps=10)\n\n    with tf.Graph().as_default():\n      global_step = tf.compat.v1.train.get_or_create_global_step()\n      no_op = tf.no_op()\n      h.begin()\n      with tf.compat.v1.Session() as sess:\n        mon_sess = monitored_session._HookedSession(sess, [h])\n        sess.run(tf.compat.v1.assign(global_step, 5))\n        h.after_create_session(sess, None)\n        mon_sess.run(no_op)\n        self.assertFalse(mon_sess.should_stop())\n        sess.run(tf.compat.v1.assign(global_step, 15))\n        mon_sess.run(no_op)\n        self.assertTrue(mon_sess.should_stop())\n\n\n@test_util.deprecated_graph_mode_only\nclass LoggingTensorHookTest(tf.test.TestCase):\n\n  def setUp(self):\n    # Mock out logging calls so we can verify whether correct tensors are being\n    # monitored.\n    self._actual_log = tf_logging.info\n    self.logged_message = None\n\n    def mock_log(*args, **kwargs):\n      self.logged_message = args\n      self._actual_log(*args, **kwargs)\n\n    tf_logging.info = mock_log\n\n  def tearDown(self):\n    tf_logging.info = self._actual_log\n\n  def test_illegal_args(self):\n    with 
self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'):\n      basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=0)\n    with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'):\n      basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=-10)\n    with self.assertRaisesRegexp(ValueError, 'xactly one of'):\n      basic_session_run_hooks.LoggingTensorHook(\n          tensors=['t'], every_n_iter=5, every_n_secs=5)\n    with self.assertRaisesRegexp(ValueError, 'xactly one of'):\n      basic_session_run_hooks.LoggingTensorHook(tensors=['t'])\n\n  def test_print_at_end_only(self):\n    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:\n      t = tf.constant(42.0, name='foo')\n      train_op = tf.constant(3)\n      hook = basic_session_run_hooks.LoggingTensorHook(\n          tensors=[t.name], at_end=True)\n      hook.begin()\n      mon_sess = monitored_session._HookedSession(sess, [hook])\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      self.logged_message = ''\n      for _ in range(3):\n        mon_sess.run(train_op)\n        # assertNotRegexpMatches is not supported by python 3.1 and later\n        self.assertEqual(str(self.logged_message).find(t.name), -1)\n\n      hook.end(sess)\n      self.assertRegexpMatches(str(self.logged_message), t.name)\n\n  def _validate_print_every_n_steps(self, sess, at_end):\n    t = tf.constant(42.0, name='foo')\n\n    train_op = tf.constant(3)\n    hook = basic_session_run_hooks.LoggingTensorHook(\n        tensors=[t.name], every_n_iter=10, at_end=at_end)\n    hook.begin()\n    mon_sess = monitored_session._HookedSession(sess, [hook])\n    self.evaluate(tf.compat.v1.initializers.global_variables())\n    mon_sess.run(train_op)\n    self.assertRegexpMatches(str(self.logged_message), t.name)\n    for _ in range(3):\n      self.logged_message = ''\n      for _ in range(9):\n        mon_sess.run(train_op)\n        # assertNotRegexpMatches is not supported by 
python 3.1 and later\n        self.assertEqual(str(self.logged_message).find(t.name), -1)\n      mon_sess.run(train_op)\n      self.assertRegexpMatches(str(self.logged_message), t.name)\n\n    # Add additional run to verify proper reset when called multiple times.\n    self.logged_message = ''\n    mon_sess.run(train_op)\n    # assertNotRegexpMatches is not supported by python 3.1 and later\n    self.assertEqual(str(self.logged_message).find(t.name), -1)\n\n    self.logged_message = ''\n    hook.end(sess)\n    if at_end:\n      self.assertRegexpMatches(str(self.logged_message), t.name)\n    else:\n      # assertNotRegexpMatches is not supported by python 3.1 and later\n      self.assertEqual(str(self.logged_message).find(t.name), -1)\n\n  def test_print_every_n_steps(self):\n    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:\n      self._validate_print_every_n_steps(sess, at_end=False)\n      # Verify proper reset.\n      self._validate_print_every_n_steps(sess, at_end=False)\n\n  def test_print_every_n_steps_and_end(self):\n    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:\n      self._validate_print_every_n_steps(sess, at_end=True)\n      # Verify proper reset.\n      self._validate_print_every_n_steps(sess, at_end=True)\n\n  def test_print_first_step(self):\n    # if it runs every iteration, first iteration has None duration.\n    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:\n      t = tf.constant(42.0, name='foo')\n      train_op = tf.constant(3)\n      hook = basic_session_run_hooks.LoggingTensorHook(\n          tensors={'foo': t}, every_n_iter=1)\n      hook.begin()\n      mon_sess = monitored_session._HookedSession(sess, [hook])\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      mon_sess.run(train_op)\n      self.assertRegexpMatches(str(self.logged_message), 'foo')\n      # in first run, elapsed time is None.\n      self.assertEqual(str(self.logged_message).find('sec'), -1)\n\n  def 
_validate_print_every_n_secs(self, sess, at_end, mock_time):\n    t = tf.constant(42.0, name='foo')\n    train_op = tf.constant(3)\n\n    hook = basic_session_run_hooks.LoggingTensorHook(\n        tensors=[t.name], every_n_secs=1.0, at_end=at_end)\n    hook.begin()\n    mon_sess = monitored_session._HookedSession(sess, [hook])\n    self.evaluate(tf.compat.v1.initializers.global_variables())\n\n    mon_sess.run(train_op)\n    self.assertRegexpMatches(str(self.logged_message), t.name)\n\n    # assertNotRegexpMatches is not supported by python 3.1 and later\n    self.logged_message = ''\n    mon_sess.run(train_op)\n    self.assertEqual(str(self.logged_message).find(t.name), -1)\n    mock_time.return_value += 1.0\n\n    self.logged_message = ''\n    mon_sess.run(train_op)\n    self.assertRegexpMatches(str(self.logged_message), t.name)\n\n    self.logged_message = ''\n    hook.end(sess)\n    if at_end:\n      self.assertRegexpMatches(str(self.logged_message), t.name)\n    else:\n      # assertNotRegexpMatches is not supported by python 3.1 and later\n      self.assertEqual(str(self.logged_message).find(t.name), -1)\n\n  @tf.compat.v1.test.mock.patch.object(time, 'time')\n  def test_print_every_n_secs(self, mock_time):\n    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:\n      mock_time.return_value = MOCK_START_TIME\n      self._validate_print_every_n_secs(sess, at_end=False, mock_time=mock_time)\n      # Verify proper reset.\n      self._validate_print_every_n_secs(sess, at_end=False, mock_time=mock_time)\n\n  @tf.compat.v1.test.mock.patch.object(time, 'time')\n  def test_print_every_n_secs_and_end(self, mock_time):\n    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:\n      mock_time.return_value = MOCK_START_TIME\n      self._validate_print_every_n_secs(sess, at_end=True, mock_time=mock_time)\n      # Verify proper reset.\n      self._validate_print_every_n_secs(sess, at_end=True, mock_time=mock_time)\n\n  def 
test_print_formatter(self):\n    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:\n      t = tf.constant(42.0, name='foo')\n      train_op = tf.constant(3)\n      hook = basic_session_run_hooks.LoggingTensorHook(\n          tensors=[t.name],\n          every_n_iter=10,\n          formatter=lambda items: 'qqq=%s' % items[t.name])\n      hook.begin()\n      mon_sess = monitored_session._HookedSession(sess, [hook])\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      mon_sess.run(train_op)\n      self.assertEqual(self.logged_message[0], 'qqq=42.0')\n\n\n@test_util.deprecated_graph_mode_only\nclass CheckpointSaverHookTest(tf.test.TestCase):\n\n  def setUp(self):\n    self.model_dir = tempfile.mkdtemp()\n    self.graph = tf.Graph()\n    with self.graph.as_default():\n      self.scaffold = tf.compat.v1.train.Scaffold()\n      self.global_step = tf.compat.v1.train.get_or_create_global_step()\n      self.train_op = training_util._increment_global_step(1)\n\n  def tearDown(self):\n    shutil.rmtree(self.model_dir, ignore_errors=True)\n\n  def test_saves_when_saver_and_scaffold_both_missing(self):\n    with self.graph.as_default():\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir, save_steps=1)\n      hook.begin()\n      self.scaffold.finalize()\n      with tf.compat.v1.Session() as sess:\n        sess.run(self.scaffold.init_op)\n        mon_sess = monitored_session._HookedSession(sess, [hook])\n        mon_sess.run(self.train_op)\n        self.assertEqual(\n            1, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n  def test_raise_when_saver_and_scaffold_both_present(self):\n    with self.assertRaises(ValueError):\n      basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir, saver=self.scaffold.saver, scaffold=self.scaffold)\n\n  def test_raise_in_both_secs_and_steps(self):\n    with self.assertRaises(ValueError):\n      
basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir, save_secs=10, save_steps=20)\n\n  def test_raise_in_none_secs_and_steps(self):\n    with self.assertRaises(ValueError):\n      basic_session_run_hooks.CheckpointSaverHook(self.model_dir)\n\n  def test_save_secs_saves_in_first_step(self):\n    with self.graph.as_default():\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir, save_secs=2, scaffold=self.scaffold)\n      hook.begin()\n      self.scaffold.finalize()\n      with tf.compat.v1.Session() as sess:\n        sess.run(self.scaffold.init_op)\n        mon_sess = monitored_session._HookedSession(sess, [hook])\n        mon_sess.run(self.train_op)\n        self.assertEqual(\n            1, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n  def test_save_secs_calls_listeners_at_begin_and_end(self):\n    with self.graph.as_default():\n      listener = MockCheckpointSaverListener()\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir,\n          save_secs=2,\n          scaffold=self.scaffold,\n          listeners=[listener])\n      hook.begin()\n      self.scaffold.finalize()\n      with tf.compat.v1.Session() as sess:\n        sess.run(self.scaffold.init_op)\n        mon_sess = monitored_session._HookedSession(sess, [hook])\n        mon_sess.run(self.train_op)  # hook runs here\n        mon_sess.run(self.train_op)  # hook won't run here, so it does at end\n        hook.end(sess)  # hook runs here\n      self.assertEqual({\n          'begin': 1,\n          'before_save': 2,\n          'after_save': 2,\n          'end': 1\n      }, listener.get_counts())\n\n  def test_listener_with_monitored_session(self):\n    with tf.Graph().as_default():\n      scaffold = tf.compat.v1.train.Scaffold()\n      global_step = tf.compat.v1.train.get_or_create_global_step()\n      train_op = training_util._increment_global_step(1)\n      listener = MockCheckpointSaverListener()\n      
hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir, save_steps=1, scaffold=scaffold, listeners=[listener])\n      with tf.compat.v1.train.SingularMonitoredSession(\n          hooks=[hook], scaffold=scaffold,\n          checkpoint_dir=self.model_dir) as sess:\n        sess.run(train_op)\n        sess.run(train_op)\n        global_step_val = sess.raw_session().run(global_step)\n      listener_counts = listener.get_counts()\n    self.assertEqual(2, global_step_val)\n    self.assertEqual({\n        'begin': 1,\n        'before_save': 3,\n        'after_save': 3,\n        'end': 1\n    }, listener_counts)\n\n  def test_listener_stops_training_in_after_save(self):\n    with tf.Graph().as_default():\n      scaffold = tf.compat.v1.train.Scaffold()\n      tf.compat.v1.train.get_or_create_global_step()\n      train_op = training_util._increment_global_step(1)\n      listener = MockCheckpointSaverListener()\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir, save_steps=1, scaffold=scaffold, listeners=[listener])\n      with tf.compat.v1.train.SingularMonitoredSession(\n          hooks=[hook], scaffold=scaffold,\n          checkpoint_dir=self.model_dir) as sess:\n        sess.run(train_op)\n        self.assertFalse(sess.should_stop())\n        sess.run(train_op)\n        self.assertFalse(sess.should_stop())\n        listener.ask_for_stop = True\n        sess.run(train_op)\n        self.assertTrue(sess.should_stop())\n\n  def test_listener_with_default_saver(self):\n    with tf.Graph().as_default():\n      global_step = tf.compat.v1.train.get_or_create_global_step()\n      train_op = training_util._increment_global_step(1)\n      listener = MockCheckpointSaverListener()\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir, save_steps=1, listeners=[listener])\n      with tf.compat.v1.train.SingularMonitoredSession(\n          hooks=[hook], checkpoint_dir=self.model_dir) as sess:\n        
sess.run(train_op)\n        sess.run(train_op)\n        global_step_val = sess.raw_session().run(global_step)\n      listener_counts = listener.get_counts()\n    self.assertEqual(2, global_step_val)\n    self.assertEqual({\n        'begin': 1,\n        'before_save': 3,\n        'after_save': 3,\n        'end': 1\n    }, listener_counts)\n\n    with tf.Graph().as_default():\n      global_step = tf.compat.v1.train.get_or_create_global_step()\n      with tf.compat.v1.train.SingularMonitoredSession(\n          checkpoint_dir=self.model_dir) as sess2:\n        global_step_saved_val = sess2.run(global_step)\n    self.assertEqual(2, global_step_saved_val)\n\n  def test_two_listeners_with_default_saver(self):\n    with tf.Graph().as_default():\n      global_step = tf.compat.v1.train.get_or_create_global_step()\n      train_op = training_util._increment_global_step(1)\n      listener1 = MockCheckpointSaverListener()\n      listener2 = MockCheckpointSaverListener()\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir, save_steps=1, listeners=[listener1, listener2])\n      with tf.compat.v1.train.SingularMonitoredSession(\n          hooks=[hook], checkpoint_dir=self.model_dir) as sess:\n        sess.run(train_op)\n        sess.run(train_op)\n        global_step_val = sess.raw_session().run(global_step)\n      listener1_counts = listener1.get_counts()\n      listener2_counts = listener2.get_counts()\n    self.assertEqual(2, global_step_val)\n    self.assertEqual({\n        'begin': 1,\n        'before_save': 3,\n        'after_save': 3,\n        'end': 1\n    }, listener1_counts)\n    self.assertEqual(listener1_counts, listener2_counts)\n\n    with tf.Graph().as_default():\n      global_step = tf.compat.v1.train.get_or_create_global_step()\n      with tf.compat.v1.train.SingularMonitoredSession(\n          checkpoint_dir=self.model_dir) as sess2:\n        global_step_saved_val = sess2.run(global_step)\n    self.assertEqual(2, 
global_step_saved_val)\n\n  @tf.compat.v1.test.mock.patch.object(time, 'time')\n  def test_save_secs_saves_periodically(self, mock_time):\n    with self.graph.as_default():\n      mock_time.return_value = MOCK_START_TIME\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir, save_secs=2, scaffold=self.scaffold)\n      hook.begin()\n      self.scaffold.finalize()\n\n      with tf.compat.v1.Session() as sess:\n        sess.run(self.scaffold.init_op)\n        mon_sess = monitored_session._HookedSession(sess, [hook])\n\n        mock_time.return_value = MOCK_START_TIME\n        mon_sess.run(self.train_op)  # Saved.\n\n        mock_time.return_value = MOCK_START_TIME + 0.5\n        mon_sess.run(self.train_op)  # Not saved.\n\n        self.assertEqual(\n            1, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n        # Simulate 2.5 seconds of sleep.\n        mock_time.return_value = MOCK_START_TIME + 2.5\n        mon_sess.run(self.train_op)  # Saved.\n\n        mock_time.return_value = MOCK_START_TIME + 2.6\n        mon_sess.run(self.train_op)  # Not saved.\n\n        mock_time.return_value = MOCK_START_TIME + 2.7\n        mon_sess.run(self.train_op)  # Not saved.\n\n        self.assertEqual(\n            3, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n        # Simulate 7.5 more seconds of sleep (10 seconds from start.\n        mock_time.return_value = MOCK_START_TIME + 10\n        mon_sess.run(self.train_op)  # Saved.\n        self.assertEqual(\n            6, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n  @tf.compat.v1.test.mock.patch.object(time, 'time')\n  def test_save_secs_calls_listeners_periodically(self, mock_time):\n    with self.graph.as_default():\n      mock_time.return_value = MOCK_START_TIME\n      listener = MockCheckpointSaverListener()\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir,\n          save_secs=2,\n          
scaffold=self.scaffold,\n          listeners=[listener])\n      hook.begin()\n      self.scaffold.finalize()\n      with tf.compat.v1.Session() as sess:\n        sess.run(self.scaffold.init_op)\n        mon_sess = monitored_session._HookedSession(sess, [hook])\n\n        mock_time.return_value = MOCK_START_TIME + 0.5\n        mon_sess.run(self.train_op)  # hook runs here\n\n        mock_time.return_value = MOCK_START_TIME + 0.5\n        mon_sess.run(self.train_op)\n\n        mock_time.return_value = MOCK_START_TIME + 3.0\n        mon_sess.run(self.train_op)  # hook runs here\n\n        mock_time.return_value = MOCK_START_TIME + 3.5\n        mon_sess.run(self.train_op)\n\n        mock_time.return_value = MOCK_START_TIME + 4.0\n        mon_sess.run(self.train_op)\n\n        mock_time.return_value = MOCK_START_TIME + 6.5\n        mon_sess.run(self.train_op)  # hook runs here\n\n        mock_time.return_value = MOCK_START_TIME + 7.0\n        mon_sess.run(self.train_op)  # hook won't run here, so it does at end\n\n        mock_time.return_value = MOCK_START_TIME + 7.5\n        hook.end(sess)  # hook runs here\n      self.assertEqual({\n          'begin': 1,\n          'before_save': 4,\n          'after_save': 4,\n          'end': 1\n      }, listener.get_counts())\n\n  def test_save_steps_saves_in_first_step(self):\n    with self.graph.as_default():\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir, save_steps=2, scaffold=self.scaffold)\n      hook.begin()\n      self.scaffold.finalize()\n      with tf.compat.v1.Session() as sess:\n        sess.run(self.scaffold.init_op)\n        mon_sess = monitored_session._HookedSession(sess, [hook])\n        mon_sess.run(self.train_op)\n        self.assertEqual(\n            1, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n  def test_save_steps_saves_periodically(self):\n    with self.graph.as_default():\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          
self.model_dir, save_steps=2, scaffold=self.scaffold)\n      hook.begin()\n      self.scaffold.finalize()\n      with tf.compat.v1.Session() as sess:\n        sess.run(self.scaffold.init_op)\n        mon_sess = monitored_session._HookedSession(sess, [hook])\n        mon_sess.run(self.train_op)\n        mon_sess.run(self.train_op)\n        # Not saved\n        self.assertEqual(\n            1, tf.train.load_variable(self.model_dir, self.global_step.name))\n        mon_sess.run(self.train_op)\n        # saved\n        self.assertEqual(\n            3, tf.train.load_variable(self.model_dir, self.global_step.name))\n        mon_sess.run(self.train_op)\n        # Not saved\n        self.assertEqual(\n            3, tf.train.load_variable(self.model_dir, self.global_step.name))\n        mon_sess.run(self.train_op)\n        # saved\n        self.assertEqual(\n            5, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n  def test_save_saves_at_end(self):\n    with self.graph.as_default():\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir, save_secs=2, scaffold=self.scaffold)\n      hook.begin()\n      self.scaffold.finalize()\n      with tf.compat.v1.Session() as sess:\n        sess.run(self.scaffold.init_op)\n        mon_sess = monitored_session._HookedSession(sess, [hook])\n        mon_sess.run(self.train_op)\n        mon_sess.run(self.train_op)\n        hook.end(sess)\n        self.assertEqual(\n            2, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n  def test_summary_writer_defs(self):\n    fake_summary_writer.FakeSummaryWriter.install()\n    tf.compat.v1.summary.FileWriterCache.clear()\n    summary_writer = tf.compat.v1.summary.FileWriterCache.get(self.model_dir)\n\n    with self.graph.as_default():\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir, save_steps=2, scaffold=self.scaffold)\n      hook.begin()\n      self.scaffold.finalize()\n      with 
tf.compat.v1.Session() as sess:\n        sess.run(self.scaffold.init_op)\n        mon_sess = monitored_session._HookedSession(sess, [hook])\n        hook.after_create_session(sess, None)\n        mon_sess.run(self.train_op)\n      summary_writer.assert_summaries(\n          test_case=self,\n          expected_logdir=self.model_dir,\n          expected_added_meta_graphs=[\n              meta_graph.create_meta_graph_def(\n                  graph_def=self.graph.as_graph_def(add_shapes=True),\n                  saver_def=self.scaffold.saver.saver_def)\n          ])\n\n    fake_summary_writer.FakeSummaryWriter.uninstall()\n\n  def test_save_checkpoint_before_first_train_step(self):\n    with self.graph.as_default():\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir, save_steps=2, scaffold=self.scaffold)\n      hook.begin()\n      self.scaffold.finalize()\n      with tf.compat.v1.Session() as sess:\n        mon_sess = monitored_session._HookedSession(sess, [hook])\n        sess.run(self.scaffold.init_op)\n        hook.after_create_session(sess, None)\n        # Verifies that checkpoint is saved at step 0.\n        self.assertEqual(\n            0, tf.train.load_variable(self.model_dir, self.global_step.name))\n        # Verifies that no checkpoint is saved after one training step.\n        mon_sess.run(self.train_op)\n        self.assertEqual(\n            0, tf.train.load_variable(self.model_dir, self.global_step.name))\n        # Verifies that checkpoint is saved after save_steps.\n        mon_sess.run(self.train_op)\n        self.assertEqual(\n            2, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n\n@test_util.deprecated_graph_mode_only\nclass CheckpointSaverHookMultiStepTest(tf.test.TestCase):\n\n  def setUp(self):\n    self.model_dir = tempfile.mkdtemp()\n    self.graph = tf.Graph()\n    self.steps_per_run = 5\n    with self.graph.as_default():\n      self.scaffold = tf.compat.v1.train.Scaffold()\n      
self.global_step = tf.compat.v1.train.get_or_create_global_step()\n      self.train_op = training_util._increment_global_step(self.steps_per_run)\n\n  def tearDown(self):\n    shutil.rmtree(self.model_dir, ignore_errors=True)\n\n  def test_save_steps_saves_in_first_step(self):\n    with self.graph.as_default():\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir,\n          save_steps=2 * self.steps_per_run,\n          scaffold=self.scaffold)\n      hook._set_steps_per_run(self.steps_per_run)\n      hook.begin()\n      self.scaffold.finalize()\n      with tf.compat.v1.Session() as sess:\n        sess.run(self.scaffold.init_op)\n        mon_sess = monitored_session._HookedSession(sess, [hook])\n        mon_sess.run(self.train_op)\n        self.assertEqual(\n            5, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n  def test_save_steps_saves_periodically(self):\n    with self.graph.as_default():\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir,\n          save_steps=2 * self.steps_per_run,\n          scaffold=self.scaffold)\n      hook._set_steps_per_run(self.steps_per_run)\n      hook.begin()\n      self.scaffold.finalize()\n      with tf.compat.v1.Session() as sess:\n        sess.run(self.scaffold.init_op)\n        mon_sess = monitored_session._HookedSession(sess, [hook])\n        mon_sess.run(self.train_op)\n        # Saved (step=5)\n        self.assertEqual(\n            5, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n        mon_sess.run(self.train_op)\n        # Not saved (step=10)\n        self.assertEqual(\n            5, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n        mon_sess.run(self.train_op)\n        # Saved (step=15)\n        self.assertEqual(\n            15, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n        mon_sess.run(self.train_op)\n        # Not saved (step=20)\n        
self.assertEqual(\n            15, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n        mon_sess.run(self.train_op)\n        # Saved (step=25)\n        self.assertEqual(\n            25, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n  def test_save_steps_saves_at_end(self):\n    with self.graph.as_default():\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir,\n          save_steps=2 * self.steps_per_run,\n          scaffold=self.scaffold)\n      hook._set_steps_per_run(self.steps_per_run)\n      hook.begin()\n      self.scaffold.finalize()\n      with tf.compat.v1.Session() as sess:\n        sess.run(self.scaffold.init_op)\n        mon_sess = monitored_session._HookedSession(sess, [hook])\n        mon_sess.run(self.train_op)\n        mon_sess.run(self.train_op)\n        hook.end(sess)\n        self.assertEqual(\n            10, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n\n@test_util.deprecated_graph_mode_only\nclass ResourceCheckpointSaverHookTest(tf.test.TestCase):\n\n  def setUp(self):\n    self.model_dir = tempfile.mkdtemp()\n    self.graph = tf.Graph()\n    with self.graph.as_default():\n      self.scaffold = tf.compat.v1.train.Scaffold()\n      with tf.compat.v1.variable_scope('foo', use_resource=True):\n        self.global_step = tf.compat.v1.train.get_or_create_global_step()\n      self.train_op = training_util._increment_global_step(1)\n\n  def test_save_steps_saves_periodically(self):\n    with self.graph.as_default():\n      hook = basic_session_run_hooks.CheckpointSaverHook(\n          self.model_dir, save_steps=2, scaffold=self.scaffold)\n      hook.begin()\n      self.scaffold.finalize()\n      with tf.compat.v1.Session() as sess:\n        sess.run(self.scaffold.init_op)\n        mon_sess = monitored_session._HookedSession(sess, [hook])\n        mon_sess.run(self.train_op)\n        mon_sess.run(self.train_op)\n        # Not saved\n        
self.assertEqual(\n            1, tf.train.load_variable(self.model_dir, self.global_step.name))\n        mon_sess.run(self.train_op)\n        # saved\n        self.assertEqual(\n            3, tf.train.load_variable(self.model_dir, self.global_step.name))\n        mon_sess.run(self.train_op)\n        # Not saved\n        self.assertEqual(\n            3, tf.train.load_variable(self.model_dir, self.global_step.name))\n        mon_sess.run(self.train_op)\n        # saved\n        self.assertEqual(\n            5, tf.train.load_variable(self.model_dir, self.global_step.name))\n\n\n@test_util.deprecated_graph_mode_only\nclass StepCounterHookTest(tf.test.TestCase):\n\n  def setUp(self):\n    self.log_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    shutil.rmtree(self.log_dir, ignore_errors=True)\n\n  @tf.compat.v1.test.mock.patch.object(time, 'time')\n  def test_step_counter_every_n_steps(self, mock_time):\n    mock_time.return_value = MOCK_START_TIME\n    with tf.Graph().as_default() as g, tf.compat.v1.Session() as sess:\n      tf.compat.v1.train.get_or_create_global_step()\n      train_op = training_util._increment_global_step(1)\n      summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)\n      hook = basic_session_run_hooks.StepCounterHook(\n          summary_writer=summary_writer, every_n_steps=10)\n      hook.begin()\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      mon_sess = monitored_session._HookedSession(sess, [hook])\n      with tf.compat.v1.test.mock.patch.object(tf_logging,\n                                               'warning') as mock_log:\n        for _ in range(30):\n          mock_time.return_value += 0.01\n          mon_sess.run(train_op)\n        # logging.warning should not be called.\n        self.assertIsNone(mock_log.call_args)\n      hook.end(sess)\n      summary_writer.assert_summaries(\n          test_case=self,\n          expected_logdir=self.log_dir,\n          expected_graph=g,\n         
 expected_summaries={})\n      self.assertItemsEqual([11, 21], summary_writer.summaries.keys())\n      for step in [11, 21]:\n        summary_value = summary_writer.summaries[step][0].value[0]\n        self.assertEqual('global_step/sec', summary_value.tag)\n        self.assertGreater(summary_value.simple_value, 0)\n\n  @tf.compat.v1.test.mock.patch.object(time, 'time')\n  def test_step_counter_every_n_secs(self, mock_time):\n    mock_time.return_value = MOCK_START_TIME\n    with tf.Graph().as_default() as g, tf.compat.v1.Session() as sess:\n      tf.compat.v1.train.get_or_create_global_step()\n      train_op = training_util._increment_global_step(1)\n      summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)\n      hook = basic_session_run_hooks.StepCounterHook(\n          summary_writer=summary_writer, every_n_steps=None, every_n_secs=0.1)\n\n      hook.begin()\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      mon_sess = monitored_session._HookedSession(sess, [hook])\n      mon_sess.run(train_op)\n      mock_time.return_value += 0.2\n      mon_sess.run(train_op)\n      mock_time.return_value += 0.2\n      mon_sess.run(train_op)\n      hook.end(sess)\n\n      summary_writer.assert_summaries(\n          test_case=self,\n          expected_logdir=self.log_dir,\n          expected_graph=g,\n          expected_summaries={})\n      self.assertTrue(summary_writer.summaries, 'No summaries were created.')\n      self.assertItemsEqual([2, 3], summary_writer.summaries.keys())\n      for summary in summary_writer.summaries.values():\n        summary_value = summary[0].value[0]\n        self.assertEqual('global_step/sec', summary_value.tag)\n        self.assertGreater(summary_value.simple_value, 0)\n\n  def test_global_step_name(self):\n    with tf.Graph().as_default() as g, tf.compat.v1.Session() as sess:\n      with tf.compat.v1.variable_scope('bar'):\n        tf.compat.v1.get_variable(\n            'foo',\n            
initializer=0,\n            trainable=False,\n            collections=[\n                tf.compat.v1.GraphKeys.GLOBAL_STEP,\n                tf.compat.v1.GraphKeys.GLOBAL_VARIABLES\n            ])\n      train_op = training_util._increment_global_step(1)\n      summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)\n      hook = basic_session_run_hooks.StepCounterHook(\n          summary_writer=summary_writer, every_n_steps=1, every_n_secs=None)\n\n      hook.begin()\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      mon_sess = monitored_session._HookedSession(sess, [hook])\n      mon_sess.run(train_op)\n      mon_sess.run(train_op)\n      hook.end(sess)\n\n      summary_writer.assert_summaries(\n          test_case=self,\n          expected_logdir=self.log_dir,\n          expected_graph=g,\n          expected_summaries={})\n      self.assertTrue(summary_writer.summaries, 'No summaries were created.')\n      self.assertItemsEqual([2], summary_writer.summaries.keys())\n      summary_value = summary_writer.summaries[2][0].value[0]\n      self.assertEqual('bar/foo/sec', summary_value.tag)\n\n  def test_log_warning_if_global_step_not_increased(self):\n    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:\n      tf.compat.v1.train.get_or_create_global_step()\n      train_op = training_util._increment_global_step(0)  # keep same.\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      hook = basic_session_run_hooks.StepCounterHook(\n          every_n_steps=1, every_n_secs=None)\n      hook.begin()\n      mon_sess = monitored_session._HookedSession(sess, [hook])\n      mon_sess.run(train_op)  # Run one step to record global step.\n      with tf.compat.v1.test.mock.patch.object(tf_logging,\n                                               'log_first_n') as mock_log:\n        for _ in range(30):\n          mon_sess.run(train_op)\n        self.assertRegexpMatches(\n            str(mock_log.call_args), 
'global step.*has not been increased')\n      hook.end(sess)\n\n  def _setup_steps_per_run_test(self, every_n_steps, steps_per_run, graph,\n                                sess):\n    tf.compat.v1.train.get_or_create_global_step()\n    self.train_op = training_util._increment_global_step(steps_per_run)\n    self.summary_writer = fake_summary_writer.FakeSummaryWriter(\n        self.log_dir, graph)\n    self.hook = basic_session_run_hooks.StepCounterHook(\n        summary_writer=self.summary_writer, every_n_steps=every_n_steps)\n    self.hook._set_steps_per_run(steps_per_run)\n    self.hook.begin()\n    self.evaluate(tf.compat.v1.initializers.global_variables())\n    self.mon_sess = monitored_session._HookedSession(sess, [self.hook])\n\n  @tf.compat.v1.test.mock.patch.object(time, 'time')\n  def test_steps_per_run_less_than_every_n_steps(self, mock_time):\n    mock_time.return_value = MOCK_START_TIME\n    with tf.Graph().as_default() as g, tf.compat.v1.Session() as sess:\n      self._setup_steps_per_run_test(10, 5, g, sess)\n\n      # Logs at 15, 25\n      for _ in range(5):\n        mock_time.return_value += 0.01\n        self.mon_sess.run(self.train_op)\n\n      self.hook.end(sess)\n      self.summary_writer.assert_summaries(\n          test_case=self,\n          expected_logdir=self.log_dir,\n          expected_graph=g,\n          expected_summaries={})\n      self.assertItemsEqual([15, 25], self.summary_writer.summaries.keys())\n      for step in [15, 25]:\n        summary_value = self.summary_writer.summaries[step][0].value[0]\n        self.assertEqual('global_step/sec', summary_value.tag)\n        self.assertGreater(summary_value.simple_value, 0)\n\n  @tf.compat.v1.test.mock.patch.object(time, 'time')\n  def test_steps_per_run_equal_every_n_steps(self, mock_time):\n    mock_time.return_value = MOCK_START_TIME\n    with tf.Graph().as_default() as g, tf.compat.v1.Session() as sess:\n      self._setup_steps_per_run_test(5, 5, g, sess)\n\n      # Logs at 10, 15, 
20, 25\n      for _ in range(5):\n        mock_time.return_value += 0.01\n        self.mon_sess.run(self.train_op)\n\n      self.hook.end(sess)\n      self.summary_writer.assert_summaries(\n          test_case=self,\n          expected_logdir=self.log_dir,\n          expected_graph=g,\n          expected_summaries={})\n      self.assertItemsEqual([10, 15, 20, 25],\n                            self.summary_writer.summaries.keys())\n      for step in [10, 15, 20, 25]:\n        summary_value = self.summary_writer.summaries[step][0].value[0]\n        self.assertEqual('global_step/sec', summary_value.tag)\n        self.assertGreater(summary_value.simple_value, 0)\n\n  @tf.compat.v1.test.mock.patch.object(time, 'time')\n  def test_steps_per_run_greater_than_every_n_steps(self, mock_time):\n    mock_time.return_value = MOCK_START_TIME\n    with tf.Graph().as_default() as g, tf.compat.v1.Session() as sess:\n      self._setup_steps_per_run_test(5, 10, g, sess)\n\n      # Logs at 20, 30, 40, 50\n      for _ in range(5):\n        mock_time.return_value += 0.01\n        self.mon_sess.run(self.train_op)\n\n      self.hook.end(sess)\n      self.summary_writer.assert_summaries(\n          test_case=self,\n          expected_logdir=self.log_dir,\n          expected_graph=g,\n          expected_summaries={})\n      self.assertItemsEqual([20, 30, 40, 50],\n                            self.summary_writer.summaries.keys())\n      for step in [20, 30, 40, 50]:\n        summary_value = self.summary_writer.summaries[step][0].value[0]\n        self.assertEqual('global_step/sec', summary_value.tag)\n        self.assertGreater(summary_value.simple_value, 0)\n\n\n@test_util.deprecated_graph_mode_only\nclass SummarySaverHookTest(tf.test.TestCase):\n\n  def setUp(self):\n    tf.test.TestCase.setUp(self)\n\n    self.log_dir = 'log/dir'\n    self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)\n\n    var = tf.Variable(0.0)\n    tensor = tf.compat.v1.assign_add(var, 1.0)\n    
tensor2 = tensor * 2\n    self.summary_op = tf.compat.v1.summary.scalar('my_summary', tensor)\n    self.summary_op2 = tf.compat.v1.summary.scalar('my_summary2', tensor2)\n\n    tf.compat.v1.train.get_or_create_global_step()\n    self.train_op = training_util._increment_global_step(1)\n\n  def test_raise_when_scaffold_and_summary_op_both_missing(self):\n    with self.assertRaises(ValueError):\n      basic_session_run_hooks.SummarySaverHook()\n\n  def test_raise_when_scaffold_and_summary_op_both_present(self):\n    with self.assertRaises(ValueError):\n      basic_session_run_hooks.SummarySaverHook(\n          scaffold=tf.compat.v1.train.Scaffold(), summary_op=self.summary_op)\n\n  def test_raise_in_both_secs_and_steps(self):\n    with self.assertRaises(ValueError):\n      basic_session_run_hooks.SummarySaverHook(\n          save_secs=10, save_steps=20, summary_writer=self.summary_writer)\n\n  def test_raise_in_none_secs_and_steps(self):\n    with self.assertRaises(ValueError):\n      basic_session_run_hooks.SummarySaverHook(\n          save_secs=None, save_steps=None, summary_writer=self.summary_writer)\n\n  def test_save_steps(self):\n    hook = basic_session_run_hooks.SummarySaverHook(\n        save_steps=8,\n        summary_writer=self.summary_writer,\n        summary_op=self.summary_op)\n\n    with self.cached_session() as sess:\n      hook.begin()\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      mon_sess = monitored_session._HookedSession(sess, [hook])\n      for _ in range(30):\n        mon_sess.run(self.train_op)\n      hook.end(sess)\n\n    self.summary_writer.assert_summaries(\n        test_case=self,\n        expected_logdir=self.log_dir,\n        expected_summaries={\n            1: {\n                'my_summary': 1.0\n            },\n            9: {\n                'my_summary': 2.0\n            },\n            17: {\n                'my_summary': 3.0\n            },\n            25: {\n                'my_summary': 4.0\n       
     },\n        })\n\n  def test_multiple_summaries(self):\n    hook = basic_session_run_hooks.SummarySaverHook(\n        save_steps=8,\n        summary_writer=self.summary_writer,\n        summary_op=[self.summary_op, self.summary_op2])\n\n    with self.cached_session() as sess:\n      hook.begin()\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      mon_sess = monitored_session._HookedSession(sess, [hook])\n      for _ in range(10):\n        mon_sess.run(self.train_op)\n      hook.end(sess)\n\n    self.summary_writer.assert_summaries(\n        test_case=self,\n        expected_logdir=self.log_dir,\n        expected_summaries={\n            1: {\n                'my_summary': 1.0,\n                'my_summary2': 2.0\n            },\n            9: {\n                'my_summary': 2.0,\n                'my_summary2': 4.0\n            },\n        })\n\n  @tf.compat.v1.test.mock.patch.object(time, 'time')\n  def test_save_secs_saving_once_every_step(self, mock_time):\n    mock_time.return_value = MOCK_START_TIME\n    hook = basic_session_run_hooks.SummarySaverHook(\n        save_secs=0.5,\n        summary_writer=self.summary_writer,\n        summary_op=self.summary_op)\n\n    with self.cached_session() as sess:\n      hook.begin()\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      mon_sess = monitored_session._HookedSession(sess, [hook])\n      for _ in range(4):\n        mon_sess.run(self.train_op)\n        mock_time.return_value += 0.5\n      hook.end(sess)\n\n    self.summary_writer.assert_summaries(\n        test_case=self,\n        expected_logdir=self.log_dir,\n        expected_summaries={\n            1: {\n                'my_summary': 1.0\n            },\n            2: {\n                'my_summary': 2.0\n            },\n            3: {\n                'my_summary': 3.0\n            },\n            4: {\n                'my_summary': 4.0\n            },\n        })\n\n  
@tf.compat.v1.test.mock.patch.object(time, 'time')\n  def test_save_secs_saving_once_every_three_steps(self, mock_time):\n    mock_time.return_value = 1484695987.209386\n    hook = basic_session_run_hooks.SummarySaverHook(\n        save_secs=9.,\n        summary_writer=self.summary_writer,\n        summary_op=self.summary_op)\n\n    with self.cached_session() as sess:\n      hook.begin()\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      mon_sess = monitored_session._HookedSession(sess, [hook])\n      for _ in range(8):\n        mon_sess.run(self.train_op)\n        mock_time.return_value += 3.1\n      hook.end(sess)\n\n    # 24.8 seconds passed (3.1*8), it saves every 9 seconds starting from first:\n    self.summary_writer.assert_summaries(\n        test_case=self,\n        expected_logdir=self.log_dir,\n        expected_summaries={\n            1: {\n                'my_summary': 1.0\n            },\n            4: {\n                'my_summary': 2.0\n            },\n            7: {\n                'my_summary': 3.0\n            },\n        })\n\n\n@test_util.deprecated_graph_mode_only\nclass GlobalStepWaiterHookTest(tf.test.TestCase):\n\n  def test_not_wait_for_step_zero(self):\n    with tf.Graph().as_default():\n      tf.compat.v1.train.get_or_create_global_step()\n      hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=0)\n      hook.begin()\n      with tf.compat.v1.Session() as sess:\n        # Before run should return without waiting gstep increment.\n        hook.before_run(\n            tf.compat.v1.train.SessionRunContext(\n                original_args=None, session=sess))\n\n  @tf.compat.v1.test.mock.patch.object(time, 'sleep')\n  def test_wait_for_step(self, mock_sleep):\n    with tf.Graph().as_default():\n      gstep = tf.compat.v1.train.get_or_create_global_step()\n      hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=1000)\n      hook.begin()\n\n      with tf.compat.v1.Session() as 
sess:\n        # Mock out calls to time.sleep() to update the global step.\n\n        class Context(object):\n          counter = 0\n\n        def mock_sleep_side_effect(seconds):\n          del seconds  # argument is ignored\n          Context.counter += 1\n          if Context.counter == 1:\n            # The first time sleep() is called, we update the global_step from\n            # 0 to 500.\n            sess.run(tf.compat.v1.assign(gstep, 500))\n          elif Context.counter == 2:\n            # The second time sleep() is called, we update the global_step from\n            # 500 to 1100.\n            sess.run(tf.compat.v1.assign(gstep, 1100))\n          else:\n            raise AssertionError(\n                'Expected before_run() to terminate after the second call to '\n                'time.sleep()')\n\n        mock_sleep.side_effect = mock_sleep_side_effect\n\n        # Run the mocked-out interaction with the hook.\n        self.evaluate(tf.compat.v1.initializers.global_variables())\n        run_context = tf.compat.v1.train.SessionRunContext(\n            original_args=None, session=sess)\n        hook.before_run(run_context)\n        self.assertEqual(Context.counter, 2)\n\n\n@test_util.deprecated_graph_mode_only\nclass FinalOpsHookTest(tf.test.TestCase):\n\n  def test_final_ops_is_scalar_tensor(self):\n    with tf.Graph().as_default():\n      expected_value = 4\n      final_ops = tf.constant(expected_value)\n\n      hook = basic_session_run_hooks.FinalOpsHook(final_ops)\n      hook.begin()\n\n      with tf.compat.v1.Session() as session:\n        hook.end(session)\n        self.assertEqual(expected_value, hook.final_ops_values)\n\n  def test_final_ops_is_tensor(self):\n    with tf.Graph().as_default():\n      expected_values = [1, 6, 3, 5, 2, 4]\n      final_ops = tf.constant(expected_values)\n\n      hook = basic_session_run_hooks.FinalOpsHook(final_ops)\n      hook.begin()\n\n      with tf.compat.v1.Session() as session:\n        hook.end(session)\n   
     self.assertListEqual(expected_values, hook.final_ops_values.tolist())\n\n  def test_final_ops_triggers_out_of_range_error(self):\n    with tf.Graph().as_default():\n      dataset = tf.compat.v1.data.Dataset.range(1)\n      iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)\n      read_ops = iterator.get_next()\n      final_ops = read_ops\n\n      hook = basic_session_run_hooks.FinalOpsHook(final_ops)\n      hook.begin()\n\n      with tf.compat.v1.Session() as session:\n        session.run(read_ops)\n        with tf.compat.v1.test.mock.patch.object(tf_logging,\n                                                 'warning') as mock_log:\n          with self.assertRaisesRegexp(tf.errors.OutOfRangeError,\n                                       'End of sequence'):\n            hook.end(session)\n          self.assertRegexpMatches(\n              str(mock_log.call_args), 'dependency back to some input source')\n\n  def test_final_ops_with_dictionary(self):\n    with tf.Graph().as_default():\n      expected_values = [4, -3]\n      final_ops = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n      final_ops_feed_dict = {final_ops: expected_values}\n\n      hook = basic_session_run_hooks.FinalOpsHook(final_ops,\n                                                  final_ops_feed_dict)\n      hook.begin()\n\n      with tf.compat.v1.Session() as session:\n        hook.end(session)\n        self.assertListEqual(expected_values, hook.final_ops_values.tolist())\n\n\n@test_util.deprecated_graph_mode_only\nclass ResourceSummarySaverHookTest(tf.test.TestCase):\n\n  def setUp(self):\n    tf.test.TestCase.setUp(self)\n\n    self.log_dir = 'log/dir'\n    self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)\n\n    var = tf.compat.v1.get_variable('var', initializer=0.0, use_resource=True)\n    tensor = tf.compat.v1.assign_add(var, 1.0)\n    self.summary_op = tf.compat.v1.summary.scalar('my_summary', tensor)\n\n    with tf.compat.v1.variable_scope('foo', 
use_resource=True):\n      tf.compat.v1.train.create_global_step()\n    self.train_op = training_util._increment_global_step(1)\n\n  def test_save_steps(self):\n    hook = basic_session_run_hooks.SummarySaverHook(\n        save_steps=8,\n        summary_writer=self.summary_writer,\n        summary_op=self.summary_op)\n\n    with self.cached_session() as sess:\n      hook.begin()\n      self.evaluate(tf.compat.v1.initializers.global_variables())\n      mon_sess = monitored_session._HookedSession(sess, [hook])\n      for _ in range(30):\n        mon_sess.run(self.train_op)\n      hook.end(sess)\n\n    self.summary_writer.assert_summaries(\n        test_case=self,\n        expected_logdir=self.log_dir,\n        expected_summaries={\n            1: {\n                'my_summary': 1.0\n            },\n            9: {\n                'my_summary': 2.0\n            },\n            17: {\n                'my_summary': 3.0\n            },\n            25: {\n                'my_summary': 4.0\n            },\n        })\n\n\n@test_util.deprecated_graph_mode_only\nclass FeedFnHookTest(tf.test.TestCase):\n\n  def test_feeding_placeholder(self):\n    with tf.Graph().as_default(), tf.compat.v1.Session() as sess:\n      x = tf.compat.v1.placeholder(dtype=tf.dtypes.float32)\n      y = x + 1\n      hook = basic_session_run_hooks.FeedFnHook(feed_fn=lambda: {x: 1.0})\n      hook.begin()\n      mon_sess = monitored_session._HookedSession(sess, [hook])\n      self.assertEqual(mon_sess.run(y), 2)\n\n\n@test_util.deprecated_graph_mode_only\nclass ProfilerHookTest(tf.test.TestCase):\n\n  def setUp(self):\n    super(ProfilerHookTest, self).setUp()\n    self.output_dir = tempfile.mkdtemp()\n    self.graph = tf.Graph()\n    self.filepattern = os.path.join(self.output_dir, 'timeline-*.json')\n    with self.graph.as_default():\n      self.global_step = tf.compat.v1.train.get_or_create_global_step()\n      self.train_op = tf.compat.v1.assign_add(self.global_step, 1)\n\n  def 
tearDown(self):\n    super(ProfilerHookTest, self).tearDown()\n    shutil.rmtree(self.output_dir, ignore_errors=True)\n\n  def _count_timeline_files(self):\n    return len(tf.compat.v1.gfile.Glob(self.filepattern))\n\n  def test_raise_in_both_secs_and_steps(self):\n    with self.assertRaises(ValueError):\n      basic_session_run_hooks.ProfilerHook(save_secs=10, save_steps=20)\n\n  def test_raise_in_none_secs_and_steps(self):\n    with self.assertRaises(ValueError):\n      basic_session_run_hooks.ProfilerHook(save_secs=None, save_steps=None)\n\n  def test_save_secs_does_not_save_in_first_step(self):\n    with self.graph.as_default():\n      hook = basic_session_run_hooks.ProfilerHook(\n          save_secs=2, output_dir=self.output_dir)\n      with tf.compat.v1.train.SingularMonitoredSession(hooks=[hook]) as sess:\n        sess.run(self.train_op)\n        self.assertEqual(0, self._count_timeline_files())\n\n  @tf.compat.v1.test.mock.patch.object(time, 'time')\n  def test_save_secs_saves_periodically(self, mock_time):\n    # Pick a fixed start time.\n    with self.graph.as_default():\n      mock_time.return_value = MOCK_START_TIME\n      hook = basic_session_run_hooks.ProfilerHook(\n          save_secs=2, output_dir=self.output_dir)\n      with tf.compat.v1.train.SingularMonitoredSession(hooks=[hook]) as sess:\n        sess.run(self.train_op)  # Not saved.\n        self.assertEqual(0, self._count_timeline_files())\n        # Simulate 2.5 seconds of sleep.\n        mock_time.return_value = MOCK_START_TIME + 2.5\n        sess.run(self.train_op)  # Saved.\n        self.assertEqual(1, self._count_timeline_files())\n\n        # Pretend some small amount of time has passed.\n        mock_time.return_value = MOCK_START_TIME + 2.6\n        sess.run(self.train_op)  # Not saved.\n        # Edge test just before we should save the timeline.\n        mock_time.return_value = MOCK_START_TIME + 4.4\n        sess.run(self.train_op)  # Not saved.\n        self.assertEqual(1, 
self._count_timeline_files())\n\n        mock_time.return_value = MOCK_START_TIME + 4.5\n        sess.run(self.train_op)  # Saved.\n        self.assertEqual(2, self._count_timeline_files())\n\n  def test_save_steps_does_not_save_in_first_step(self):\n    with self.graph.as_default():\n      hook = basic_session_run_hooks.ProfilerHook(\n          save_steps=1, output_dir=self.output_dir)\n      with tf.compat.v1.train.SingularMonitoredSession(hooks=[hook]) as sess:\n        sess.run(self.train_op)  # Not saved.\n        self.assertEqual(0, self._count_timeline_files())\n\n  def test_save_steps_saves_periodically(self):\n    with self.graph.as_default():\n      hook = basic_session_run_hooks.ProfilerHook(\n          save_steps=2, output_dir=self.output_dir)\n      with tf.compat.v1.train.SingularMonitoredSession(hooks=[hook]) as sess:\n        self.assertEqual(0, self._count_timeline_files())\n        sess.run(self.train_op)  # Not saved.\n        self.assertEqual(0, self._count_timeline_files())\n        sess.run(self.train_op)  # Saved.\n        self.assertEqual(1, self._count_timeline_files())\n        sess.run(self.train_op)  # Not saved.\n        self.assertEqual(1, self._count_timeline_files())\n        sess.run(self.train_op)  # Saved.\n        self.assertEqual(2, self._count_timeline_files())\n        sess.run(self.train_op)  # Not saved.\n        self.assertEqual(2, self._count_timeline_files())\n\n  def test_run_metadata_saves(self):\n    tf.compat.v1.summary.FileWriterCache.clear()\n    fake_summary_writer.FakeSummaryWriter.install()\n    fake_writer = tf.compat.v1.summary.FileWriterCache.get(self.output_dir)\n    with self.graph.as_default():\n      hook = basic_session_run_hooks.ProfilerHook(\n          save_steps=1, output_dir=self.output_dir)\n      with tf.compat.v1.train.SingularMonitoredSession(hooks=[hook]) as sess:\n        sess.run(self.train_op)  # Not saved.\n        sess.run(self.train_op)  # Saved.\n        self.assertEqual(\n            
list(fake_writer._added_run_metadata.keys()), ['step_2'])\n    fake_summary_writer.FakeSummaryWriter.uninstall()\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/hooks/fake_summary_writer.py",
    "content": "# Copyright 2015 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Fake summary writer for unit tests.\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom tensorflow.core.framework import summary_pb2\nfrom tensorflow.python.framework import test_util\nfrom tensorflow.python.summary.writer import writer\nfrom tensorflow.python.summary.writer import writer_cache\n\n\n# TODO(ptucker): Replace with mock framework.\nclass FakeSummaryWriter(object):\n  \"\"\"Fake summary writer.\"\"\"\n\n  _replaced_summary_writer = None\n\n  @classmethod\n  def install(cls):\n    if cls._replaced_summary_writer:\n      raise ValueError('FakeSummaryWriter already installed.')\n    cls._replaced_summary_writer = writer.FileWriter\n    writer.FileWriter = FakeSummaryWriter\n    writer_cache.FileWriter = FakeSummaryWriter\n\n  @classmethod\n  def uninstall(cls):\n    if not cls._replaced_summary_writer:\n      raise ValueError('FakeSummaryWriter not installed.')\n    writer.FileWriter = cls._replaced_summary_writer\n    writer_cache.FileWriter = cls._replaced_summary_writer\n    cls._replaced_summary_writer = None\n\n  def __init__(self, logdir, graph=None):\n    self._logdir = logdir\n    self._graph = graph\n    self._summaries = {}\n    self._added_graphs = []\n    
self._added_meta_graphs = []\n    self._added_session_logs = []\n    self._added_run_metadata = {}\n\n  @property\n  def summaries(self):\n    return self._summaries\n\n  def assert_summaries(self,\n                       test_case,\n                       expected_logdir=None,\n                       expected_graph=None,\n                       expected_summaries=None,\n                       expected_added_graphs=None,\n                       expected_added_meta_graphs=None,\n                       expected_session_logs=None):\n    \"\"\"Assert expected items have been added to summary writer.\"\"\"\n    if expected_logdir is not None:\n      test_case.assertEqual(expected_logdir, self._logdir)\n    if expected_graph is not None:\n      test_case.assertTrue(expected_graph is self._graph)\n    expected_summaries = expected_summaries or {}\n    for step in expected_summaries:\n      test_case.assertTrue(\n          step in self._summaries,\n          msg='Missing step %s from %s.' % (step, self._summaries.keys()))\n      actual_simple_values = {}\n      for step_summary in self._summaries[step]:\n        for v in step_summary.value:\n          # Ignore global_step/sec since it's written by Supervisor in a\n          # separate thread, so it's non-deterministic how many get written.\n          if 'global_step/sec' != v.tag:\n            actual_simple_values[v.tag] = v.simple_value\n      test_case.assertEqual(expected_summaries[step], actual_simple_values)\n    if expected_added_graphs is not None:\n      test_case.assertEqual(expected_added_graphs, self._added_graphs)\n    if expected_added_meta_graphs is not None:\n      test_case.assertEqual(\n          len(expected_added_meta_graphs), len(self._added_meta_graphs))\n      for expected, actual in zip(expected_added_meta_graphs,\n                                  self._added_meta_graphs):\n        test_util.assert_meta_graph_protos_equal(test_case, expected, actual)\n    if expected_session_logs is not None:\n      
test_case.assertEqual(expected_session_logs, self._added_session_logs)\n\n  def add_summary(self, summ, current_global_step):\n    \"\"\"Add summary.\"\"\"\n    if isinstance(summ, bytes):\n      summary_proto = summary_pb2.Summary()\n      summary_proto.ParseFromString(summ)\n      summ = summary_proto\n    if current_global_step in self._summaries:\n      step_summaries = self._summaries[current_global_step]\n    else:\n      step_summaries = []\n      self._summaries[current_global_step] = step_summaries\n    step_summaries.append(summ)\n\n  # NOTE: Ignore global_step since its value is non-deterministic.\n  def add_graph(self, graph, global_step=None, graph_def=None):\n    \"\"\"Add graph.\"\"\"\n    if (global_step is not None) and (global_step < 0):\n      raise ValueError('Invalid global_step %s.' % global_step)\n    if graph_def is not None:\n      raise ValueError('Unexpected graph_def %s.' % graph_def)\n    self._added_graphs.append(graph)\n\n  def add_meta_graph(self, meta_graph_def, global_step=None):\n    \"\"\"Add metagraph.\"\"\"\n    if (global_step is not None) and (global_step < 0):\n      raise ValueError('Invalid global_step %s.' % global_step)\n    self._added_meta_graphs.append(meta_graph_def)\n\n  # NOTE: Ignore global_step since its value is non-deterministic.\n  def add_session_log(self, session_log, global_step=None):\n    # pylint: disable=unused-argument\n    self._added_session_logs.append(session_log)\n\n  def add_run_metadata(self, run_metadata, tag, global_step=None):\n    if (global_step is not None) and (global_step < 0):\n      raise ValueError('Invalid global_step %s.' % global_step)\n    self._added_run_metadata[tag] = run_metadata\n\n  def flush(self):\n    pass\n\n  def reopen(self):\n    pass\n\n  def close(self):\n    pass\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/hooks/hooks.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Some useful session run hooks.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport time\nimport tensorflow as tf\nfrom tensorflow.python.training import training_util\nfrom tensorflow_estimator.python.estimator import estimator as estimator_lib\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\n\n\n# pylint: disable=protected-access\n@estimator_export('estimator.experimental.InMemoryEvaluatorHook')\nclass InMemoryEvaluatorHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Hook to run evaluation in training without a checkpoint.\n\n  Example:\n\n  ```python\n  def train_input_fn():\n    ...\n    return train_dataset\n\n  def eval_input_fn():\n    ...\n    return eval_dataset\n\n  estimator = tf.estimator.DNNClassifier(...)\n\n  evaluator = tf.estimator.experimental.InMemoryEvaluatorHook(\n      estimator, eval_input_fn)\n  estimator.train(train_input_fn, hooks=[evaluator])\n  ```\n\n  Current limitations of this approach are:\n\n  * It doesn't support multi-node distributed mode.\n  * It doesn't support saveable objects other than variables (such as boosted\n    tree support)\n  * It doesn't support custom saver logic (such as 
ExponentialMovingAverage\n    support)\n\n  \"\"\"\n\n  def __init__(self,\n               estimator,\n               input_fn,\n               steps=None,\n               hooks=None,\n               name=None,\n               every_n_iter=100):\n    \"\"\"Initializes a `InMemoryEvaluatorHook`.\n\n    Args:\n      estimator: A `tf.estimator.Estimator` instance to call evaluate.\n      input_fn:  Equivalent to the `input_fn` arg to `estimator.evaluate`. A\n        function that constructs the input data for evaluation. See [Creating\n        input functions](\n        https://tensorflow.org/guide/premade_estimators#create_input_functions)\n          for more information. The function should construct and return one of\n        the following:\n          * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a\n            tuple (features, labels) with same constraints as below.\n          * A tuple (features, labels): Where `features` is a `Tensor` or a\n            dictionary of string feature name to `Tensor` and `labels` is a\n            `Tensor` or a dictionary of string label name to `Tensor`. Both\n            `features` and `labels` are consumed by `model_fn`. They should\n            satisfy the expectation of `model_fn` from inputs.\n      steps: Equivalent to the `steps` arg to `estimator.evaluate`.  Number of\n        steps for which to evaluate model. If `None`, evaluates until `input_fn`\n        raises an end-of-input exception.\n      hooks: Equivalent to the `hooks` arg to `estimator.evaluate`. List of\n        `SessionRunHook` subclass instances. Used for callbacks inside the\n        evaluation call.\n      name:  Equivalent to the `name` arg to `estimator.evaluate`. Name of the\n        evaluation if user needs to run multiple evaluations on different data\n        sets, such as on training data vs test data. 
Metrics for different\n        evaluations are saved in separate folders, and appear separately in\n        tensorboard.\n      every_n_iter: `int`, runs the evaluator once every N training iteration.\n\n    Raises:\n      ValueError: if `every_n_iter` is non-positive or it's not a single machine\n        training\n    \"\"\"\n    if every_n_iter is None or every_n_iter <= 0:\n      raise ValueError('invalid every_n_iter=%s.' % every_n_iter)\n    if (estimator.config.num_ps_replicas > 0 or\n        estimator.config.num_worker_replicas > 1):\n      raise ValueError(\n          'InMemoryEvaluator supports only single machine (aka Local) setting.')\n    self._estimator = estimator\n    self._input_fn = input_fn\n    self._steps = steps\n    self._name = name\n    self._every_n_iter = every_n_iter\n    self._eval_dir = os.path.join(self._estimator.model_dir,\n                                  'eval' if not name else 'eval_' + name)\n\n    self._graph = None\n    self._hooks = estimator_lib._check_hooks_type(hooks)\n    self._hooks.extend(self._estimator._convert_eval_steps_to_hooks(steps))\n    self._timer = tf.compat.v1.train.SecondOrStepTimer(every_steps=every_n_iter)\n\n  def begin(self):\n    \"\"\"Build eval graph and restoring op.\"\"\"\n    self._timer.reset()\n    self._iter_count = 0\n    self._graph = tf.Graph()\n    with self._graph.as_default():\n      (self._scaffold, self._update_op, self._eval_dict,\n       self._all_hooks) = self._estimator._evaluate_build_graph(\n           self._input_fn, self._hooks, checkpoint_path=None)\n\n      if self._scaffold.saver is not None:\n        raise ValueError('InMemoryEvaluator does not support custom saver')\n      if self._scaffold.init_fn is not None:\n        raise ValueError('InMemoryEvaluator does not support custom init_fn')\n\n      self._var_name_to_eval_var = {\n          v.name: v for v in tf.compat.v1.get_collection(\n              tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)\n      }\n      
self._var_name_to_placeholder = {\n          v.name: tf.compat.v1.placeholder(v.dtype) for v in\n          tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)\n      }\n\n  def after_create_session(self, session, coord):  # pylint: disable=unused-argument\n    \"\"\"Does first run which shows the eval metrics before training.\"\"\"\n    if tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.SAVEABLE_OBJECTS):\n      raise ValueError(\n          'InMemoryEvaluator does not support saveables other than global '\n          'variables.')\n    self._var_name_to_train_var = {\n        v.name: v for v in tf.compat.v1.get_collection(\n            tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)\n    }\n    var_names_to_transfer = set(self._var_name_to_placeholder.keys()) & set(\n        self._var_name_to_train_var.keys())\n    # Filter training var names that are not exist in evaluation\n    self._var_name_to_train_var = {\n        v_name: self._var_name_to_train_var[v_name]\n        for v_name in var_names_to_transfer\n    }\n    # Filter eval var names that are not exist in training\n    self._var_name_to_eval_var = {\n        v_name: self._var_name_to_eval_var[v_name]\n        for v_name in var_names_to_transfer\n    }\n\n    with self._graph.as_default():\n      self._var_feed_op = tf.group([\n          tf.compat.v1.assign(self._var_name_to_eval_var[v_name],\n                              self._var_name_to_placeholder[v_name])\n          for v_name in var_names_to_transfer\n      ])\n\n    self._evaluate(session)\n\n  def _evaluate(self, train_session):\n    var_name_to_value = train_session.run(self._var_name_to_train_var)\n    placeholder_to_value = {\n        self._var_name_to_placeholder[v_name]: var_name_to_value[v_name]\n        for v_name in var_name_to_value\n    }\n\n    def feed_variables(scaffold, session):\n      del scaffold\n      session.run(self._var_feed_op, feed_dict=placeholder_to_value)\n\n    scaffold = tf.compat.v1.train.Scaffold(\n        
init_fn=feed_variables, copy_from_scaffold=self._scaffold)\n\n    with self._graph.as_default():\n      self._estimator._evaluate_run(\n          checkpoint_path=None,\n          scaffold=scaffold,\n          update_op=self._update_op,\n          eval_dict=self._eval_dict,\n          all_hooks=self._all_hooks,\n          output_dir=self._eval_dir)\n\n    self._timer.update_last_triggered_step(self._iter_count)\n\n  def after_run(self, run_context, run_values):  # pylint: disable=unused-argument\n    \"\"\"Runs evaluator.\"\"\"\n    self._iter_count += 1\n    if self._timer.should_trigger_for_step(self._iter_count):\n      self._evaluate(run_context.session)\n\n  def end(self, session):  # pylint: disable=unused-argument\n    \"\"\"Runs evaluator for final model.\"\"\"\n    self._evaluate(session)\n\n\nclass _StopAtCheckpointStepHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Hook that requests stop at a specified step based on checkpoint.\n\n  Note: We recommend using 'make_stop_at_checkpoint_step_hook` to get the proper\n  hook.\n  \"\"\"\n\n  def __init__(self, model_dir, last_step, wait_after_file_check_secs=30):\n    \"\"\"Initializes a `StopAtCheckpointStepHook`.\n\n    This hook requests stop after a last step has been reached. It checks latest\n    checkpoint to verify last step is written on disk or not.\n\n    Args:\n      model_dir: Directory to read global step from latest checkpoint.\n      last_step: Step after which to stop.\n      wait_after_file_check_secs: Reading same file by many workers may create\n        I/O issues. 
To throttle that we will wait given secs after each read of\n        the file.\n\n    Raises:\n      ValueError: If one of the arguments is invalid.\n    \"\"\"\n    if last_step is None:\n      raise ValueError('last_step must be specified.')\n    if model_dir is None:\n      raise ValueError('model_dir must be specified.')\n\n    self._model_dir = model_dir\n    self._last_step = last_step\n    self._wait_after_file_check_secs = wait_after_file_check_secs\n\n  def begin(self):\n    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access\n    if self._global_step_tensor is None:\n      raise RuntimeError(\n          'Global step should be created to use StopAtCheckpointStepHook.')\n\n  def before_run(self, run_context):  # pylint: disable=unused-argument\n    return tf.compat.v1.train.SessionRunArgs(self._global_step_tensor)\n\n  def after_run(self, run_context, run_values):\n    global_step = run_values.results + 1\n    if global_step >= self._last_step:\n      # Check latest global step in the checkpoint to ensure that the targeted\n      # last step is written on disk.\n\n      step = estimator_lib._load_global_step_from_checkpoint_dir(\n          self._model_dir)\n      if step >= self._last_step:\n        run_context.request_stop()\n      else:\n        time.sleep(self._wait_after_file_check_secs)\n\n\n@estimator_export('estimator.experimental.make_stop_at_checkpoint_step_hook')\ndef make_stop_at_checkpoint_step_hook(estimator,\n                                      last_step,\n                                      wait_after_file_check_secs=30):\n  \"\"\"Creates a proper StopAtCheckpointStepHook based on chief status.\"\"\"\n\n  if estimator.config.is_chief:\n    return tf.compat.v1.train.StopAtStepHook(last_step=last_step)\n  return _StopAtCheckpointStepHook(\n      model_dir=estimator.model_dir,\n      last_step=last_step,\n      wait_after_file_check_secs=wait_after_file_check_secs)\n\n\n# pylint: 
enable=protected-access\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/hooks/hooks_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for hooks.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport glob\nimport json\nimport os\nimport tempfile\nimport time\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator import estimator_lib\nfrom tensorflow_estimator.python.estimator import run_config as run_config_lib\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.hooks import hooks as hooks_lib\n\n\ndef summary_step_keyword_to_value_mapping(dir_):\n  tf.compat.v1.summary.FileWriterCache.clear()\n\n  # Get last Event written.\n  event_paths = glob.glob(os.path.join(dir_, 'events*'))\n  step_keyword_to_value = {}\n  for last_event in tf.compat.v1.train.summary_iterator(event_paths[-1]):\n    if last_event.step not in step_keyword_to_value:\n      step_keyword_to_value[last_event.step] = {}\n    if last_event.summary is not None:\n      for value in last_event.summary.value:\n        step_keyword_to_value[last_event.step][value.tag] = value.simple_value\n\n  return step_keyword_to_value\n\n\ndef get_summary_value(dir_, step, keyword):\n  \"\"\"Get summary value for given step and 
keyword.\"\"\"\n\n  tf.compat.v1.summary.FileWriterCache.clear()\n  # Get last Event written.\n  event_paths = glob.glob(os.path.join(dir_, 'events*'))\n  print('XXX', event_paths)\n  for last_event in tf.compat.v1.train.summary_iterator(event_paths[-1]):\n    if last_event.step == step and last_event.summary is not None:\n      for value in last_event.summary.value:\n        if keyword in value.tag:\n          return value.simple_value\n  return None\n\n\n@test_util.deprecated_graph_mode_only\nclass InMemoryEvaluatorHookTest(tf.test.TestCase):\n\n  def test_runs_eval_metrics(self):\n\n    def model_fn(features, labels, mode):\n      _ = labels\n      if estimator_lib.ModeKeys.TRAIN == mode:\n        with tf.control_dependencies([features]):\n          train_op = tf.compat.v1.assign_add(\n              tf.compat.v1.train.get_global_step(), 1)\n        return estimator_lib.EstimatorSpec(\n            mode, loss=tf.constant(3.), train_op=train_op)\n      if estimator_lib.ModeKeys.EVAL == mode:\n        mean = tf_keras.metrics.Mean()\n        mean.update_state(features)\n        return estimator_lib.EstimatorSpec(\n            mode,\n            loss=tf.constant(5.),\n            eval_metric_ops={\n                'mean_of_features': mean,\n            })\n\n    estimator = estimator_lib.Estimator(model_fn=model_fn)\n\n    def input_fn():\n      return tf.compat.v1.data.Dataset.range(10)\n\n    evaluator = hooks_lib.InMemoryEvaluatorHook(\n        estimator, input_fn, every_n_iter=4)\n    estimator.train(input_fn, hooks=[evaluator])\n\n    self.assertTrue(os.path.isdir(estimator.eval_dir()))\n    step_keyword_to_value = summary_step_keyword_to_value_mapping(\n        estimator.eval_dir())\n\n    # 4.5 = sum(range(10))/10\n    # before training\n    self.assertEqual(4.5, step_keyword_to_value[0]['mean_of_features'])\n    # intervals (every_n_iter=4)\n    self.assertEqual(4.5, step_keyword_to_value[4]['mean_of_features'])\n    self.assertEqual(4.5, 
step_keyword_to_value[8]['mean_of_features'])\n    # end\n    self.assertEqual(4.5, step_keyword_to_value[10]['mean_of_features'])\n    self.assertEqual(set([0, 4, 8, 10]), set(step_keyword_to_value.keys()))\n\n  def test_uses_latest_variable_value(self):\n\n    def model_fn(features, labels, mode):\n      _ = labels\n      step = tf.compat.v1.train.get_global_step()\n      w = tf.compat.v1.get_variable(\n          'w',\n          shape=[],\n          initializer=tf.compat.v1.initializers.zeros(),\n          dtype=tf.dtypes.int64)\n      if estimator_lib.ModeKeys.TRAIN == mode:\n        # to consume features, we have control dependency\n        with tf.control_dependencies([features]):\n          step_inc = tf.compat.v1.assign_add(\n              tf.compat.v1.train.get_global_step(), 1)\n        with tf.control_dependencies([step_inc]):\n          assign_w_to_step_plus_2 = w.assign(step + 2)\n        return estimator_lib.EstimatorSpec(\n            mode, loss=tf.constant(3.), train_op=assign_w_to_step_plus_2)\n      if estimator_lib.ModeKeys.EVAL == mode:\n        # to consume features, we have control dependency\n        with tf.control_dependencies([features]):\n          loss = tf.constant(5.)\n        mean = tf_keras.metrics.Mean()\n        mean.update_state(w)\n        return estimator_lib.EstimatorSpec(\n            mode,\n            loss=loss,\n            # w is constant in each step, so the mean.\n            # w = 0 if step==0 else step+2\n            eval_metric_ops={'mean_of_const': mean})\n\n    estimator = estimator_lib.Estimator(model_fn=model_fn)\n\n    def input_fn():\n      return tf.compat.v1.data.Dataset.range(10)\n\n    evaluator = hooks_lib.InMemoryEvaluatorHook(\n        estimator, input_fn, every_n_iter=4)\n    estimator.train(input_fn, hooks=[evaluator])\n\n    self.assertTrue(os.path.isdir(estimator.eval_dir()))\n    step_keyword_to_value = summary_step_keyword_to_value_mapping(\n        estimator.eval_dir())\n    # w = 0 if step==0 else 
step+2\n    self.assertEqual(0, step_keyword_to_value[0]['mean_of_const'])\n    self.assertEqual(6, step_keyword_to_value[4]['mean_of_const'])\n    self.assertEqual(12, step_keyword_to_value[10]['mean_of_const'])\n\n  def test_dnn_classifier(self):\n    embedding = tf.feature_column.embedding_column(\n        tf.feature_column.categorical_column_with_vocabulary_list(\n            'wire_cast', ['kima', 'omar', 'stringer']), 8)\n    dnn = estimator_lib.DNNClassifier(\n        feature_columns=[embedding], hidden_units=[3, 1])\n\n    def train_input_fn():\n      return tf.compat.v1.data.Dataset.from_tensors(({\n          'wire_cast': [['omar'], ['kima']]\n      }, [[0], [1]])).repeat(3)\n\n    def eval_input_fn():\n      return tf.compat.v1.data.Dataset.from_tensors(({\n          'wire_cast': [['stringer'], ['kima']]\n      }, [[0], [1]])).repeat(2)\n\n    evaluator = hooks_lib.InMemoryEvaluatorHook(\n        dnn, eval_input_fn, name='in-memory')\n    dnn.train(train_input_fn, hooks=[evaluator])\n    self.assertTrue(os.path.isdir(dnn.eval_dir('in-memory')))\n    step_keyword_to_value = summary_step_keyword_to_value_mapping(\n        dnn.eval_dir('in-memory'))\n\n    final_metrics = dnn.evaluate(eval_input_fn)\n    step = final_metrics[tf.compat.v1.GraphKeys.GLOBAL_STEP]\n    for summary_tag in final_metrics:\n      if summary_tag == tf.compat.v1.GraphKeys.GLOBAL_STEP:\n        continue\n      self.assertEqual(final_metrics[summary_tag],\n                       step_keyword_to_value[step][summary_tag])\n\n  def test_raise_error_with_multi_worker(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.CHIEF,\n            'index': 0\n        }\n    }\n    with tf.compat.v1.test.mock.patch.dict(\n        'os.environ', {'TF_CONFIG': json.dumps(tf_config)}):\n      dnn = 
estimator_lib.DNNClassifier(\n          feature_columns=[tf.feature_column.numeric_column('x')],\n          hidden_units=[3, 1])\n\n    def eval_input_fn():\n      pass\n\n    with self.assertRaisesRegexp(ValueError, 'supports only single machine'):\n      hooks_lib.InMemoryEvaluatorHook(dnn, eval_input_fn)\n\n  def test_raise_error_with_ps(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1'],\n        },\n        'task': {\n            'type': run_config_lib.TaskType.CHIEF,\n            'index': 0\n        }\n    }\n    with tf.compat.v1.test.mock.patch.dict(\n        'os.environ', {'TF_CONFIG': json.dumps(tf_config)}):\n      dnn = estimator_lib.DNNClassifier(\n          feature_columns=[tf.feature_column.numeric_column('x')],\n          hidden_units=[3, 1])\n\n    def eval_input_fn():\n      pass\n\n    with self.assertRaisesRegexp(ValueError, 'supports only single machine'):\n      hooks_lib.InMemoryEvaluatorHook(dnn, eval_input_fn)\n\n  def test_raise_error_with_custom_saver_in_eval(self):\n\n    def model_fn(features, labels, mode):\n      _, _ = features, labels\n      mean = tf_keras.metrics.Mean()\n      mean.update_state(tf.constant(2.))\n      return estimator_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(3.),\n          scaffold=tf.compat.v1.train.Scaffold(\n              saver=tf.compat.v1.train.Saver()),\n          train_op=tf.constant(5.),\n          eval_metric_ops={\n              'mean_of_features': mean,\n          })\n\n    estimator = estimator_lib.Estimator(model_fn=model_fn)\n\n    def input_fn():\n      return tf.compat.v1.data.Dataset.range(10)\n\n    evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn)\n    with self.assertRaisesRegexp(ValueError, 'does not support custom saver'):\n      evaluator.begin()\n\n  def test_raise_error_with_custom_init_fn_in_eval(self):\n\n    def model_fn(features, labels, 
mode):\n      _, _ = features, labels\n\n      def init_fn(scaffold, session):\n        _, _ = scaffold, session\n\n      mean = tf_keras.metrics.Mean()\n      mean.update_state(tf.constant(2.))\n      return estimator_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(3.),\n          scaffold=tf.compat.v1.train.Scaffold(init_fn=init_fn),\n          train_op=tf.constant(5.),\n          eval_metric_ops={\n              'mean_of_features': mean,\n          })\n\n    estimator = estimator_lib.Estimator(model_fn=model_fn)\n\n    def input_fn():\n      return tf.compat.v1.data.Dataset.range(10)\n\n    evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn)\n    with self.assertRaisesRegexp(ValueError, 'does not support custom init_fn'):\n      evaluator.begin()\n\n  def test_raise_error_with_saveables_other_than_global_variables(self):\n\n    def model_fn(features, labels, mode):\n      _, _ = features, labels\n      w = tf.compat.v1.Variable(\n          initial_value=[0.],\n          trainable=False,\n          collections=[tf.compat.v1.GraphKeys.SAVEABLE_OBJECTS])\n      init_op = tf.group(\n          [w.initializer,\n           tf.compat.v1.train.get_global_step().initializer])\n\n      mean = tf_keras.metrics.Mean()\n      mean.update_state(tf.constant(2.))\n      return estimator_lib.EstimatorSpec(\n          mode,\n          loss=tf.constant(3.),\n          scaffold=tf.compat.v1.train.Scaffold(init_op=init_op),\n          train_op=tf.constant(5.),\n          eval_metric_ops={\n              'mean_of_features': mean,\n          })\n\n    estimator = estimator_lib.Estimator(model_fn=model_fn)\n\n    def input_fn():\n      return tf.compat.v1.data.Dataset.range(10)\n\n    evaluator = hooks_lib.InMemoryEvaluatorHook(estimator, input_fn)\n    with self.assertRaisesRegexp(ValueError, 'does not support saveables'):\n      estimator.train(input_fn, hooks=[evaluator])\n\n\n@test_util.deprecated_graph_mode_only\nclass 
StopAtCheckpointStepHookTest(tf.test.TestCase):\n\n  def test_do_not_stop_if_checkpoint_is_not_there(self):\n    with tf.Graph().as_default():\n      step = tf.compat.v1.train.create_global_step()\n      assign_ten = step.assign(10)\n      no_op = tf.no_op()\n      hook = hooks_lib._StopAtCheckpointStepHook(\n          model_dir=tempfile.mkdtemp(), last_step=10)\n      with tf.compat.v1.train.SingularMonitoredSession(\n          hooks=[hook]) as mon_sess:\n        mon_sess.raw_session().run(assign_ten)\n        with tf.compat.v1.test.mock.patch.object(time, 'sleep') as mock_sleep:\n          mon_sess.run(no_op)\n          self.assertTrue(mock_sleep.called)\n        self.assertFalse(mon_sess.should_stop())\n\n  def test_do_not_stop_if_checkpoint_step_is_smaller(self):\n    model_dir = tempfile.mkdtemp()\n    with tf.Graph().as_default():\n      step = tf.compat.v1.train.create_global_step()\n      assign_nine = step.assign(9)\n      assign_ten = step.assign(10)\n      no_op = tf.no_op()\n      hook = hooks_lib._StopAtCheckpointStepHook(\n          model_dir=model_dir, last_step=10)\n      with tf.compat.v1.Session() as sess:\n        sess.run(assign_nine)\n        tf.compat.v1.train.Saver().save(sess,\n                                        os.path.join(model_dir, 'model.ckpt'))\n      with tf.compat.v1.train.SingularMonitoredSession(\n          hooks=[hook]) as mon_sess:\n        mon_sess.raw_session().run(assign_ten)\n        with tf.compat.v1.test.mock.patch.object(time, 'sleep') as mock_sleep:\n          mon_sess.run(no_op)\n          self.assertTrue(mock_sleep.called)\n        self.assertFalse(mon_sess.should_stop())\n\n  def test_stop_if_checkpoint_step_is_laststep(self):\n    model_dir = tempfile.mkdtemp()\n    with tf.Graph().as_default():\n      step = tf.compat.v1.train.create_global_step()\n      assign_ten = step.assign(10)\n      no_op = tf.no_op()\n      hook = hooks_lib._StopAtCheckpointStepHook(\n          model_dir=model_dir, last_step=10)\n      
with tf.compat.v1.Session() as sess:\n        sess.run(assign_ten)\n        tf.compat.v1.train.Saver().save(sess,\n                                        os.path.join(model_dir, 'model.ckpt'))\n      with tf.compat.v1.train.SingularMonitoredSession(\n          hooks=[hook]) as mon_sess:\n        mon_sess.raw_session().run(assign_ten)\n        with tf.compat.v1.test.mock.patch.object(time, 'sleep') as mock_sleep:\n          mon_sess.run(no_op)\n          self.assertFalse(mock_sleep.called)\n        self.assertTrue(mon_sess.should_stop())\n\n  def test_creates_regular_stop_at_step_hook_for_chief(self):\n    # by default an estimator is in chief mode\n    dnn = estimator_lib.DNNClassifier(\n        feature_columns=[tf.feature_column.numeric_column('x')],\n        hidden_units=[3, 1])\n    hook = hooks_lib.make_stop_at_checkpoint_step_hook(dnn, 300)\n    self.assertIsInstance(hook, tf.compat.v1.train.StopAtStepHook)\n    self.assertEqual(300, hook._last_step)\n\n  def test_creates_checkpoint_hook_for_workers(self):\n\n    class FakeWorkerConfig(estimator_lib.RunConfig):\n\n      @property\n      def is_chief(self):\n        return False\n\n    dnn = estimator_lib.DNNClassifier(\n        feature_columns=[tf.feature_column.numeric_column('x')],\n        hidden_units=[3, 1],\n        config=FakeWorkerConfig())\n    hook = hooks_lib.make_stop_at_checkpoint_step_hook(dnn, 300)\n    self.assertIsInstance(hook, hooks_lib._StopAtCheckpointStepHook)\n    self.assertEqual(300, hook._last_step)\n    self.assertEqual(dnn.model_dir, hook._model_dir)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/hooks/session_run_hook.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"A SessionRunHook extends `session.run()` calls for the `MonitoredSession`.\n\nSessionRunHooks are useful to track training, report progress, request early\nstopping and more. SessionRunHooks use the observer pattern and notify at the\nfollowing points:\n - when a session starts being used\n - before a call to the `session.run()`\n - after a call to the `session.run()`\n - when the session closed\n\nA SessionRunHook encapsulates a piece of reusable/composable computation that\ncan piggyback a call to `MonitoredSession.run()`. A hook can add any\nops-or-tensor/feeds to the run call, and when the run call finishes with success\ngets the outputs it requested. Hooks are allowed to add ops to the graph in\n`hook.begin()`. 
The graph is finalized after the `begin()` method is called.\n\nThere are a few pre-defined hooks:\n - StopAtStepHook: Request stop based on global_step\n - CheckpointSaverHook: saves checkpoint\n - LoggingTensorHook: outputs one or more tensor values to log\n - NanTensorHook: Request stop if given `Tensor` contains Nans.\n - SummarySaverHook: saves summaries to a summary writer\n\nFor more specific needs, you can create custom hooks:\n  class ExampleHook(SessionRunHook):\n    def begin(self):\n      # You can add ops to the graph here.\n      print('Starting the session.')\n      self.your_tensor = ...\n\n    def after_create_session(self, session, coord):\n      # When this is called, the graph is finalized and\n      # ops can no longer be added to the graph.\n      print('Session created.')\n\n    def before_run(self, run_context):\n      print('Before calling session.run().')\n      return SessionRunArgs(self.your_tensor)\n\n    def after_run(self, run_context, run_values):\n      print('Done running one step. The value of my tensor: %s',\n            run_values.results)\n      if you-need-to-stop-loop:\n        run_context.request_stop()\n\n    def end(self, session):\n      print('Done with the session.')\n\nTo understand how hooks interact with calls to `MonitoredSession.run()`,\nlook at following code:\n  with MonitoredTrainingSession(hooks=your_hooks, ...) 
as sess:\n    while not sess.should_stop():\n      sess.run(your_fetches)\n\nAbove user code leads to following execution:\n  call hooks.begin()\n  sess = tf.Session()\n  call hooks.after_create_session()\n  while not stop is requested:\n    call hooks.before_run()\n    try:\n      results = sess.run(merged_fetches, feed_dict=merged_feeds)\n    except (errors.OutOfRangeError, StopIteration):\n      break\n    call hooks.after_run()\n  call hooks.end()\n  sess.close()\n\nNote that if sess.run() raises OutOfRangeError or StopIteration then\nhooks.after_run() will not be called but hooks.end() will still be called.\nIf sess.run() raises any other exception then neither hooks.after_run() nor\nhooks.end() will be called.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\nfrom tensorflow.python.training.session_run_hook import SessionRunArgs\nfrom tensorflow.python.training.session_run_hook import SessionRunContext\nfrom tensorflow.python.training.session_run_hook import SessionRunHook\nfrom tensorflow.python.training.session_run_hook import SessionRunValues\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\n\nestimator_export(\"estimator.SessionRunHook\")(SessionRunHook)\nestimator_export(\"estimator.SessionRunArgs\")(SessionRunArgs)\nestimator_export(\"estimator.SessionRunContext\")(SessionRunContext)\nestimator_export(\"estimator.SessionRunValues\")(SessionRunValues)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/inputs/__init__.py",
    "content": ""
  },
  {
    "path": "tensorflow_estimator/python/estimator/inputs/inputs.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Utility methods to create simple input_fns.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n# pylint: disable=unused-import,line-too-long\nfrom tensorflow_estimator.python.estimator.inputs.numpy_io import numpy_input_fn\nfrom tensorflow_estimator.python.estimator.inputs.pandas_io import pandas_input_fn\n\n# pylint: enable=unused-import,line-too-long\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/inputs/numpy_io.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Methods to allow dict of numpy arrays.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\n\nimport numpy as np\nfrom six import string_types\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.inputs.queues import feeding_functions\n\n# Key name to pack the target into dict of `features`. See\n# `_get_unique_target_key` for details.\n_TARGET_KEY = '__target_key__'\n\n\ndef _get_unique_target_key(features):\n  \"\"\"Returns a key not existed in the input dict `features`.\n\n  Caller of `input_fn` usually provides `features` (dict of numpy arrays) and\n  `target`, but the underlying feeding module expects a single dict of numpy\n  arrays as input. So, the `target` needs to be packed into the `features`\n  temporarily and unpacked after calling the feeding function. 
Toward this goal,\n  this function returns a key not existed in the `features` to pack the\n  `target`.\n\n  Args:\n    features: OrderedDict of numpy arrays\n\n  Returns:\n    A unique key that can be used to insert the subsequent target into\n      features dict.\n  \"\"\"\n  target_key = _TARGET_KEY\n  while target_key in features:\n    target_key += '_n'\n  return target_key\n\n\ndef _validate_and_convert_features(x):\n  \"\"\"Type check input data and make a shadow copy as an ordered dict.\n\n  Args:\n    x: numpy array object or dict of numpy array objects. If an array, the array\n      will be treated as a single feature.\n\n  Returns:\n    OrderedDict copy of x.\n\n  Raises:\n    ValueError: if x is empty\n    TypeError: if x is an unknown type.\n  \"\"\"\n  if isinstance(x, dict):\n    if not x:\n      raise ValueError('x cannot be an empty dict')\n    # Make a shadow copy and also ensure the order of iteration is consistent.\n    ordered_dict_data = collections.OrderedDict(\n        sorted(x.items(), key=lambda t: t[0]))\n  elif isinstance(x, np.ndarray):\n    if x.size == 0:\n      raise ValueError('x cannot be an empty array')\n\n    # Make a shadow copy and convert to dict to align with dict processing.\n    ordered_dict_data = collections.OrderedDict({'__direct_np_input__': x})\n  else:\n    x_type = type(x).__name__\n    raise TypeError('x must be a dict or array; got {}'.format(x_type))\n\n  return ordered_dict_data\n\n\n@estimator_export(v1=['estimator.inputs.numpy_input_fn'])\ndef numpy_input_fn(x,\n                   y=None,\n                   batch_size=128,\n                   num_epochs=1,\n                   shuffle=None,\n                   queue_capacity=1000,\n                   num_threads=1):\n  \"\"\"Returns input function that would feed dict of numpy arrays into the model.\n\n  This returns a function outputting `features` and `targets` based on the dict\n  of numpy arrays. The dict `features` has the same keys as the `x`. 
The dict\n  `targets` has the same keys as the `y` if `y` is a dict.\n\n  Example:\n\n  ```python\n  age = np.arange(4) * 1.0\n  height = np.arange(32, 36)\n  x = {'age': age, 'height': height}\n  y = np.arange(-32, -28)\n\n  with tf.Session() as session:\n    input_fn = numpy_io.numpy_input_fn(\n        x, y, batch_size=2, shuffle=False, num_epochs=1)\n  ```\n\n  Args:\n    x: numpy array object or dict of numpy array objects. If an array, the array\n      will be treated as a single feature.\n    y: numpy array object or dict of numpy array object. `None` if absent.\n    batch_size: Integer, size of batches to return.\n    num_epochs: Integer, number of epochs to iterate over data. If `None` will\n      run forever.\n    shuffle: Boolean, if True shuffles the queue. Avoid shuffle at prediction\n      time.\n    queue_capacity: Integer, size of queue to accumulate.\n    num_threads: Integer, number of threads used for reading and enqueueing. In\n      order to have predicted and repeatable order of reading and enqueueing,\n      such as in prediction and evaluation mode, `num_threads` should be 1.\n\n  Returns:\n    Function, that has signature of ()->(dict of `features`, `targets`)\n\n  Raises:\n    ValueError: if the shape of `y` mismatches the shape of values in `x` (i.e.,\n      values in `x` have same shape).\n    ValueError: if duplicate keys are in both `x` and `y` when `y` is a dict.\n    ValueError: if x or y is an empty dict.\n    TypeError: `x` is not a dict or array.\n    ValueError: if 'shuffle' is not provided or a bool.\n  \"\"\"\n  if not isinstance(shuffle, bool):\n    raise ValueError('shuffle must be provided and explicitly set as boolean '\n                     '(it is recommended to set it as True for training); '\n                     'got {}'.format(shuffle))\n\n  def input_fn():\n    \"\"\"Numpy input function.\"\"\"\n\n    # Note that `x` should not be used after conversion to ordered_dict_data,\n    # as type could be either dict or 
array.\n    ordered_dict_data = _validate_and_convert_features(x)\n\n    # Deep copy keys which is a view in python 3\n    feature_keys = list(ordered_dict_data.keys())\n\n    if y is None:\n      target_keys = None\n    elif isinstance(y, dict):\n      if not y:\n        raise ValueError('y cannot be empty dict, use None instead.')\n\n      ordered_dict_y = collections.OrderedDict(\n          sorted(y.items(), key=lambda t: t[0]))\n      target_keys = list(ordered_dict_y.keys())\n\n      duplicate_keys = set(feature_keys).intersection(set(target_keys))\n      if duplicate_keys:\n        raise ValueError('{} duplicate keys are found in both x and y: '\n                         '{}'.format(len(duplicate_keys), duplicate_keys))\n\n      ordered_dict_data.update(ordered_dict_y)\n    else:\n      target_keys = _get_unique_target_key(ordered_dict_data)\n      ordered_dict_data[target_keys] = y\n\n    if len(set(v.shape[0] for v in ordered_dict_data.values())) != 1:\n      shape_dict_of_x = {k: ordered_dict_data[k].shape for k in feature_keys}\n\n      if target_keys is None:\n        shape_of_y = None\n      elif isinstance(target_keys, string_types):\n        shape_of_y = y.shape\n      else:\n        shape_of_y = {k: ordered_dict_data[k].shape for k in target_keys}\n\n      raise ValueError('Length of tensors in x and y is mismatched. 
All '\n                       'elements in x and y must have the same length.\\n'\n                       'Shapes in x: {}\\n'\n                       'Shapes in y: {}\\n'.format(shape_dict_of_x, shape_of_y))\n\n    queue = feeding_functions._enqueue_data(  # pylint: disable=protected-access\n        ordered_dict_data,\n        queue_capacity,\n        shuffle=shuffle,\n        num_threads=num_threads,\n        enqueue_size=batch_size,\n        num_epochs=num_epochs)\n\n    batch = (\n        queue.dequeue_many(batch_size)\n        if num_epochs is None else queue.dequeue_up_to(batch_size))\n\n    # Remove the first `Tensor` in `batch`, which is the row number.\n    if batch:\n      batch.pop(0)\n\n    if isinstance(x, np.ndarray):\n      # Return as the same type as original array.\n      features = batch[0]\n    else:\n      # Return as the original dict type\n      features = dict(zip(feature_keys, batch[:len(feature_keys)]))\n\n    if target_keys is None:\n      # TODO(martinwicke), return consistent result\n      return features\n    elif isinstance(target_keys, string_types):\n      target = batch[-1]\n      return features, target\n    else:\n      target = dict(zip(target_keys, batch[-len(target_keys):]))\n      return features, target\n\n  return input_fn\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/inputs/numpy_io_test.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for numpy_io.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.feature_column.feature_column import _LinearModel\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass NumpyIoTest(tf.test.TestCase):\n\n  def testNumpyInputFn(self):\n    a = np.arange(4) * 1.0\n    b = np.arange(32, 36)\n    x = {'a': a, 'b': b}\n    y = np.arange(-32, -28)\n\n    with self.cached_session() as session:\n      input_fn = numpy_io.numpy_input_fn(\n          x, y, batch_size=2, shuffle=False, num_epochs=1)\n      features, target = input_fn()\n\n      coord = tf.train.Coordinator()\n      threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n          session, coord=coord)\n\n      res = session.run([features, target])\n      self.assertAllEqual(res[0]['a'], [0, 1])\n      self.assertAllEqual(res[0]['b'], [32, 33])\n      self.assertAllEqual(res[1], [-32, -31])\n\n      session.run([features, target])\n      with self.assertRaises(tf.errors.OutOfRangeError):\n        session.run([features, 
target])\n\n      coord.request_stop()\n      coord.join(threads)\n\n  def testNumpyInputFnWithVeryLargeBatchSizeAndMultipleEpochs(self):\n    a = np.arange(2) * 1.0\n    b = np.arange(32, 34)\n    x = {'a': a, 'b': b}\n    y = np.arange(-32, -30)\n\n    with self.cached_session() as session:\n      input_fn = numpy_io.numpy_input_fn(\n          x, y, batch_size=128, shuffle=False, num_epochs=2)\n      features, target = input_fn()\n\n      coord = tf.train.Coordinator()\n      threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n          session, coord=coord)\n\n      res = session.run([features, target])\n      self.assertAllEqual(res[0]['a'], [0, 1, 0, 1])\n      self.assertAllEqual(res[0]['b'], [32, 33, 32, 33])\n      self.assertAllEqual(res[1], [-32, -31, -32, -31])\n\n      with self.assertRaises(tf.errors.OutOfRangeError):\n        session.run([features, target])\n\n      coord.request_stop()\n      coord.join(threads)\n\n  def testNumpyInputFnWithZeroEpochs(self):\n    a = np.arange(4) * 1.0\n    b = np.arange(32, 36)\n    x = {'a': a, 'b': b}\n    y = np.arange(-32, -28)\n\n    with self.cached_session() as session:\n      input_fn = numpy_io.numpy_input_fn(\n          x, y, batch_size=2, shuffle=False, num_epochs=0)\n      features, target = input_fn()\n\n      coord = tf.train.Coordinator()\n      threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n          session, coord=coord)\n\n      with self.assertRaises(tf.errors.OutOfRangeError):\n        session.run([features, target])\n\n      coord.request_stop()\n      coord.join(threads)\n\n  def testNumpyInputFnWithBatchSizeNotDividedByDataSize(self):\n    batch_size = 2\n    a = np.arange(5) * 1.0\n    b = np.arange(32, 37)\n    x = {'a': a, 'b': b}\n    y = np.arange(-32, -27)\n\n    with self.cached_session() as session:\n      input_fn = numpy_io.numpy_input_fn(\n          x, y, batch_size=batch_size, shuffle=False, num_epochs=1)\n      features, target = input_fn()\n\n      
coord = tf.train.Coordinator()\n      threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n          session, coord=coord)\n\n      res = session.run([features, target])\n      self.assertAllEqual(res[0]['a'], [0, 1])\n      self.assertAllEqual(res[0]['b'], [32, 33])\n      self.assertAllEqual(res[1], [-32, -31])\n\n      res = session.run([features, target])\n      self.assertAllEqual(res[0]['a'], [2, 3])\n      self.assertAllEqual(res[0]['b'], [34, 35])\n      self.assertAllEqual(res[1], [-30, -29])\n\n      res = session.run([features, target])\n      self.assertAllEqual(res[0]['a'], [4])\n      self.assertAllEqual(res[0]['b'], [36])\n      self.assertAllEqual(res[1], [-28])\n\n      with self.assertRaises(tf.errors.OutOfRangeError):\n        session.run([features, target])\n\n      coord.request_stop()\n      coord.join(threads)\n\n  def testNumpyInputFnWithBatchSizeNotDividedByDataSizeAndMultipleEpochs(self):\n    batch_size = 2\n    a = np.arange(3) * 1.0\n    b = np.arange(32, 35)\n    x = {'a': a, 'b': b}\n    y = np.arange(-32, -29)\n\n    with self.cached_session() as session:\n      input_fn = numpy_io.numpy_input_fn(\n          x, y, batch_size=batch_size, shuffle=False, num_epochs=3)\n      features, target = input_fn()\n\n      coord = tf.train.Coordinator()\n      threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n          session, coord=coord)\n\n      res = session.run([features, target])\n      self.assertAllEqual(res[0]['a'], [0, 1])\n      self.assertAllEqual(res[0]['b'], [32, 33])\n      self.assertAllEqual(res[1], [-32, -31])\n\n      res = session.run([features, target])\n      self.assertAllEqual(res[0]['a'], [2, 0])\n      self.assertAllEqual(res[0]['b'], [34, 32])\n      self.assertAllEqual(res[1], [-30, -32])\n\n      res = session.run([features, target])\n      self.assertAllEqual(res[0]['a'], [1, 2])\n      self.assertAllEqual(res[0]['b'], [33, 34])\n      self.assertAllEqual(res[1], [-31, -30])\n\n      res = 
session.run([features, target])\n      self.assertAllEqual(res[0]['a'], [0, 1])\n      self.assertAllEqual(res[0]['b'], [32, 33])\n      self.assertAllEqual(res[1], [-32, -31])\n\n      res = session.run([features, target])\n      self.assertAllEqual(res[0]['a'], [2])\n      self.assertAllEqual(res[0]['b'], [34])\n      self.assertAllEqual(res[1], [-30])\n\n      with self.assertRaises(tf.errors.OutOfRangeError):\n        session.run([features, target])\n\n      coord.request_stop()\n      coord.join(threads)\n\n  def testNumpyInputFnWithBatchSizeLargerThanDataSize(self):\n    batch_size = 10\n    a = np.arange(4) * 1.0\n    b = np.arange(32, 36)\n    x = {'a': a, 'b': b}\n    y = np.arange(-32, -28)\n\n    with self.cached_session() as session:\n      input_fn = numpy_io.numpy_input_fn(\n          x, y, batch_size=batch_size, shuffle=False, num_epochs=1)\n      features, target = input_fn()\n\n      coord = tf.train.Coordinator()\n      threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n          session, coord=coord)\n\n      res = session.run([features, target])\n      self.assertAllEqual(res[0]['a'], [0, 1, 2, 3])\n      self.assertAllEqual(res[0]['b'], [32, 33, 34, 35])\n      self.assertAllEqual(res[1], [-32, -31, -30, -29])\n\n      with self.assertRaises(tf.errors.OutOfRangeError):\n        session.run([features, target])\n\n      coord.request_stop()\n      coord.join(threads)\n\n  def testNumpyInputFnWithDifferentDimensionsOfFeatures(self):\n    a = np.array([[1, 2], [3, 4]])\n    b = np.array([5, 6])\n    x = {'a': a, 'b': b}\n    y = np.arange(-32, -30)\n\n    with self.cached_session() as session:\n      input_fn = numpy_io.numpy_input_fn(\n          x, y, batch_size=2, shuffle=False, num_epochs=1)\n      features, target = input_fn()\n\n      coord = tf.train.Coordinator()\n      threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n          session, coord=coord)\n\n      res = session.run([features, target])\n      
self.assertAllEqual(res[0]['a'], [[1, 2], [3, 4]])\n      self.assertAllEqual(res[0]['b'], [5, 6])\n      self.assertAllEqual(res[1], [-32, -31])\n\n      coord.request_stop()\n      coord.join(threads)\n\n  def testNumpyInputFnWithXAsNonDict(self):\n    x = list(range(32, 36))\n    y = np.arange(4)\n    with self.cached_session():\n      with self.assertRaisesRegexp(TypeError, 'x must be a dict or array'):\n        failing_input_fn = numpy_io.numpy_input_fn(\n            x, y, batch_size=2, shuffle=False, num_epochs=1)\n        failing_input_fn()\n\n  def testNumpyInputFnWithXIsEmptyDict(self):\n    x = {}\n    y = np.arange(4)\n    with self.cached_session():\n      with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'):\n        failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)\n        failing_input_fn()\n\n  def testNumpyInputFnWithXIsEmptyArray(self):\n    x = np.array([[], []])\n    y = np.arange(4)\n    with self.cached_session():\n      with self.assertRaisesRegexp(ValueError, 'x cannot be an empty'):\n        failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)\n        failing_input_fn()\n\n  def testNumpyInputFnWithYIsNone(self):\n    a = np.arange(4) * 1.0\n    b = np.arange(32, 36)\n    x = {'a': a, 'b': b}\n    y = None\n\n    with self.cached_session() as session:\n      input_fn = numpy_io.numpy_input_fn(\n          x, y, batch_size=2, shuffle=False, num_epochs=1)\n      features_tensor = input_fn()\n\n      coord = tf.train.Coordinator()\n      threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n          session, coord=coord)\n\n      feature = session.run(features_tensor)\n      self.assertEqual(len(feature), 2)\n      self.assertAllEqual(feature['a'], [0, 1])\n      self.assertAllEqual(feature['b'], [32, 33])\n\n      session.run([features_tensor])\n      with self.assertRaises(tf.errors.OutOfRangeError):\n        session.run([features_tensor])\n\n      coord.request_stop()\n      
coord.join(threads)\n\n  def testNumpyInputFnWithNonBoolShuffle(self):\n    x = np.arange(32, 36)\n    y = np.arange(4)\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          ValueError, 'shuffle must be provided and explicitly '\n          'set as boolean'):\n        # Default shuffle is None.\n        numpy_io.numpy_input_fn(x, y)\n\n  def testNumpyInputFnWithTargetKeyAlreadyInX(self):\n    array = np.arange(32, 36)\n    x = {'__target_key__': array}\n    y = np.arange(4)\n\n    with self.cached_session():\n      input_fn = numpy_io.numpy_input_fn(\n          x, y, batch_size=2, shuffle=False, num_epochs=1)\n      input_fn()\n      self.assertAllEqual(x['__target_key__'], array)\n      # The input x should not be mutated.\n      self.assertItemsEqual(x.keys(), ['__target_key__'])\n\n  def testNumpyInputFnWithMismatchLengthOfInputs(self):\n    a = np.arange(4) * 1.0\n    b = np.arange(32, 36)\n    x = {'a': a, 'b': b}\n    x_mismatch_length = {'a': np.arange(1), 'b': b}\n    y_longer_length = np.arange(10)\n\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          ValueError, 'Length of tensors in x and y is mismatched.'):\n        failing_input_fn = numpy_io.numpy_input_fn(\n            x, y_longer_length, batch_size=2, shuffle=False, num_epochs=1)\n        failing_input_fn()\n\n      with self.assertRaisesRegexp(\n          ValueError, 'Length of tensors in x and y is mismatched.'):\n        failing_input_fn = numpy_io.numpy_input_fn(\n            x=x_mismatch_length,\n            y=None,\n            batch_size=2,\n            shuffle=False,\n            num_epochs=1)\n        failing_input_fn()\n\n  def testNumpyInputFnWithYAsDict(self):\n    a = np.arange(4) * 1.0\n    b = np.arange(32, 36)\n    x = {'a': a, 'b': b}\n    y = {'y1': np.arange(-32, -28), 'y2': np.arange(32, 28, -1)}\n\n    with self.cached_session() as session:\n      input_fn = numpy_io.numpy_input_fn(\n          x, y, batch_size=2, 
shuffle=False, num_epochs=1)\n      features_tensor, targets_tensor = input_fn()\n\n      coord = tf.train.Coordinator()\n      threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n          session, coord=coord)\n\n      features, targets = session.run([features_tensor, targets_tensor])\n      self.assertEqual(len(features), 2)\n      self.assertAllEqual(features['a'], [0, 1])\n      self.assertAllEqual(features['b'], [32, 33])\n      self.assertEqual(len(targets), 2)\n      self.assertAllEqual(targets['y1'], [-32, -31])\n      self.assertAllEqual(targets['y2'], [32, 31])\n\n      session.run([features_tensor, targets_tensor])\n      with self.assertRaises(tf.errors.OutOfRangeError):\n        session.run([features_tensor, targets_tensor])\n\n      coord.request_stop()\n      coord.join(threads)\n\n  def testNumpyInputFnWithYIsEmptyDict(self):\n    a = np.arange(4) * 1.0\n    b = np.arange(32, 36)\n    x = {'a': a, 'b': b}\n    y = {}\n    with self.cached_session():\n      with self.assertRaisesRegexp(ValueError, 'y cannot be empty'):\n        failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)\n        failing_input_fn()\n\n  def testNumpyInputFnWithDuplicateKeysInXAndY(self):\n    a = np.arange(4) * 1.0\n    b = np.arange(32, 36)\n    x = {'a': a, 'b': b}\n    y = {'y1': np.arange(-32, -28), 'a': a, 'y2': np.arange(32, 28, -1), 'b': b}\n    with self.cached_session():\n      with self.assertRaisesRegexp(\n          ValueError, '2 duplicate keys are found in both x and y'):\n        failing_input_fn = numpy_io.numpy_input_fn(x, y, shuffle=False)\n        failing_input_fn()\n\n  def testNumpyInputFnWithXIsArray(self):\n    x = np.arange(4) * 1.0\n    y = np.arange(-32, -28)\n\n    input_fn = numpy_io.numpy_input_fn(\n        x, y, batch_size=2, shuffle=False, num_epochs=1)\n    features, target = input_fn()\n\n    with tf.compat.v1.train.MonitoredSession() as session:\n      res = session.run([features, target])\n      
self.assertAllEqual(res[0], [0, 1])\n      self.assertAllEqual(res[1], [-32, -31])\n\n      session.run([features, target])\n      with self.assertRaises(tf.errors.OutOfRangeError):\n        session.run([features, target])\n\n  def testNumpyInputFnWithXIsNDArray(self):\n    x = np.arange(16).reshape(4, 2, 2) * 1.0\n    y = np.arange(-48, -32).reshape(4, 2, 2)\n\n    input_fn = numpy_io.numpy_input_fn(\n        x, y, batch_size=2, shuffle=False, num_epochs=1)\n    features, target = input_fn()\n\n    with tf.compat.v1.train.MonitoredSession() as session:\n      res = session.run([features, target])\n      self.assertAllEqual(res[0], [[[0, 1], [2, 3]], [[4, 5], [6, 7]]])\n      self.assertAllEqual(res[1],\n                          [[[-48, -47], [-46, -45]], [[-44, -43], [-42, -41]]])\n\n      session.run([features, target])\n      with self.assertRaises(tf.errors.OutOfRangeError):\n        session.run([features, target])\n\n  def testNumpyInputFnWithXIsArrayYIsDict(self):\n    x = np.arange(4) * 1.0\n    y = {'y1': np.arange(-32, -28)}\n\n    input_fn = numpy_io.numpy_input_fn(\n        x, y, batch_size=2, shuffle=False, num_epochs=1)\n    features_tensor, targets_tensor = input_fn()\n\n    with tf.compat.v1.train.MonitoredSession() as session:\n      features, targets = session.run([features_tensor, targets_tensor])\n      self.assertEqual(len(features), 2)\n      self.assertAllEqual(features, [0, 1])\n      self.assertEqual(len(targets), 1)\n      self.assertAllEqual(targets['y1'], [-32, -31])\n\n      session.run([features_tensor, targets_tensor])\n      with self.assertRaises(tf.errors.OutOfRangeError):\n        session.run([features_tensor, targets_tensor])\n\n  def testArrayAndDictGiveSameOutput(self):\n    a = np.arange(4) * 1.0\n    b = np.arange(32, 36)\n    x_arr = np.vstack((a, b))\n    x_dict = {'feature1': x_arr}\n    y = np.arange(-48, -40).reshape(2, 4)\n\n    input_fn_arr = numpy_io.numpy_input_fn(\n        x_arr, y, batch_size=2, shuffle=False, 
num_epochs=1)\n    features_arr, targets_arr = input_fn_arr()\n\n    input_fn_dict = numpy_io.numpy_input_fn(\n        x_dict, y, batch_size=2, shuffle=False, num_epochs=1)\n    features_dict, targets_dict = input_fn_dict()\n\n    with tf.compat.v1.train.MonitoredSession() as session:\n      res_arr, res_dict = session.run([(features_arr, targets_arr),\n                                       (features_dict, targets_dict)])\n\n      self.assertAllEqual(res_arr[0], res_dict[0]['feature1'])\n      self.assertAllEqual(res_arr[1], res_dict[1])\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass FeatureColumnIntegrationTest(tf.test.TestCase):\n\n  def _initialized_session(self, config=None):\n    sess = tf.compat.v1.Session(config=config)\n    sess.run(tf.compat.v1.initializers.global_variables())\n    sess.run(tf.compat.v1.initializers.tables_initializer())\n    return sess\n\n  def _get_linear_model_bias(self, name='linear_model'):\n    with tf.compat.v1.variable_scope(name, reuse=True):\n      return tf.compat.v1.get_variable('bias_weights')\n\n  def _get_linear_model_column_var(self, column, name='linear_model'):\n    return tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES,\n                                       name + '/' + column.name)[0]\n\n  def _get_keras_linear_model_predictions(self,\n                                          features,\n                                          feature_columns,\n                                          units=1,\n                                          sparse_combiner='sum',\n                                          weight_collections=None,\n                                          trainable=True,\n                                          cols_to_vars=None):\n    keras_linear_model = _LinearModel(\n        feature_columns,\n        units,\n        sparse_combiner,\n        weight_collections,\n        trainable,\n        name='linear_model')\n    retval = keras_linear_model(features)  # pylint: 
disable=not-callable\n    if cols_to_vars is not None:\n      cols_to_vars.update(keras_linear_model.cols_to_vars())\n    return retval\n\n  def test_linear_model_numpy_input_fn(self):\n    price = tf.feature_column.numeric_column('price')\n    price_buckets = tf.feature_column.bucketized_column(\n        price, boundaries=[0., 10., 100.,])\n    body_style = tf.feature_column.categorical_column_with_vocabulary_list(\n        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])\n\n    input_fn = numpy_io.numpy_input_fn(\n        x={\n            'price': np.array([-1., 2., 13., 104.]),\n            'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),\n        },\n        batch_size=2,\n        shuffle=False)\n    features = input_fn()\n    net = tf.compat.v1.feature_column.linear_model(features,\n                                                   [price_buckets, body_style])\n    # self.assertEqual(1 + 3 + 5, net.shape[1])\n    with self._initialized_session() as sess:\n      coord = tf.train.Coordinator()\n      threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n          sess, coord=coord)\n\n      bias = self._get_linear_model_bias()\n      price_buckets_var = self._get_linear_model_column_var(price_buckets)\n      body_style_var = self._get_linear_model_column_var(body_style)\n\n      sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))\n      sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))\n      sess.run(bias.assign([5.]))\n\n      self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))\n\n      coord.request_stop()\n      coord.join(threads)\n\n  def test_linear_model_impl_numpy_input_fn(self):\n    price = tf.feature_column.numeric_column('price')\n    price_buckets = tf.feature_column.bucketized_column(\n        price, boundaries=[\n            0.,\n            10.,\n            100.,\n        ])\n    body_style = tf.feature_column.categorical_column_with_vocabulary_list(\n       
 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])\n\n    input_fn = numpy_io.numpy_input_fn(\n        x={\n            'price': np.array([-1., 2., 13., 104.]),\n            'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),\n        },\n        batch_size=2,\n        shuffle=False)\n    features = input_fn()\n    net = self._get_keras_linear_model_predictions(features,\n                                                   [price_buckets, body_style])\n    # self.assertEqual(1 + 3 + 5, net.shape[1])\n    with self._initialized_session() as sess:\n      coord = tf.train.Coordinator()\n      threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n          sess, coord=coord)\n\n      bias = self._get_linear_model_bias()\n      price_buckets_var = self._get_linear_model_column_var(price_buckets)\n      body_style_var = self._get_linear_model_column_var(body_style)\n\n      sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))\n      sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))\n      sess.run(bias.assign([5.]))\n\n      self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))\n\n      coord.request_stop()\n      coord.join(threads)\n\n  def test_functional_input_layer_with_numpy_input_fn(self):\n    embedding_values = (\n        (1., 2., 3., 4., 5.),  # id 0\n        (6., 7., 8., 9., 10.),  # id 1\n        (11., 12., 13., 14., 15.)  
# id 2\n    )\n\n    def _initializer(shape, dtype, partition_info):\n      del shape, dtype, partition_info\n      return embedding_values\n\n    # price has 1 dimension in input_layer\n    price = tf.feature_column.numeric_column('price')\n    body_style = tf.feature_column.categorical_column_with_vocabulary_list(\n        'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])\n    # one_hot_body_style has 3 dims in input_layer.\n    one_hot_body_style = tf.feature_column.indicator_column(body_style)\n    # embedded_body_style has 5 dims in input_layer.\n    embedded_body_style = tf.feature_column.embedding_column(\n        body_style, dimension=5, initializer=_initializer)\n\n    input_fn = numpy_io.numpy_input_fn(\n        x={\n            'price': np.array([11., 12., 13., 14.]),\n            'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),\n        },\n        batch_size=2,\n        shuffle=False)\n    features = input_fn()\n    net = tf.compat.v1.feature_column.input_layer(\n        features, [price, one_hot_body_style, embedded_body_style])\n    self.assertEqual(1 + 3 + 5, net.shape[1])\n    with self._initialized_session() as sess:\n      coord = tf.train.Coordinator()\n      threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n          sess, coord=coord)\n\n      # Each row is formed by concatenating `embedded_body_style`,\n      # `one_hot_body_style`, and `price` in order.\n      self.assertAllEqual([[11., 12., 13., 14., 15., 0., 0., 1., 11.],\n                           [1., 2., 3., 4., 5., 1., 0., 0., 12]], sess.run(net))\n\n      coord.request_stop()\n      coord.join(threads)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/inputs/pandas_io.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Methods to allow pandas.DataFrame.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport uuid\nimport numpy as np\nimport six\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.inputs.queues import feeding_functions\n\ntry:\n  # pylint: disable=g-import-not-at-top\n  # pylint: disable=unused-import\n  import pandas as pd\n  HAS_PANDAS = True\nexcept IOError:\n  # Pandas writes a temporary file during import. 
If it fails, don't use pandas.\n  HAS_PANDAS = False\nexcept ImportError:\n  HAS_PANDAS = False\n\n\ndef _get_unique_target_key(features, target_column_name):\n  \"\"\"Returns a key that does not exist in the input DataFrame `features`.\n\n  Args:\n    features: DataFrame\n    target_column_name: Name of the target column as a `str`\n\n  Returns:\n    A unique key that can be used to insert the target into\n      features.\n  \"\"\"\n  if target_column_name in features:\n    target_column_name += '_' + str(uuid.uuid4())\n  return target_column_name\n\n\n@estimator_export(v1=['estimator.inputs.pandas_input_fn'])\ndef pandas_input_fn(x,\n                    y=None,\n                    batch_size=128,\n                    num_epochs=1,\n                    shuffle=None,\n                    queue_capacity=1000,\n                    num_threads=1,\n                    target_column='target'):\n  \"\"\"Returns input function that would feed Pandas DataFrame into the model.\n\n  Note: `y`'s index must match `x`'s index.\n\n  Args:\n    x: pandas `DataFrame` object.\n    y: pandas `Series` object or `DataFrame`. `None` if absent.\n    batch_size: int, size of batches to return.\n    num_epochs: int, number of epochs to iterate over data. If not `None`, read\n      attempts that would exceed this value will raise `OutOfRangeError`.\n    shuffle: bool, whether to read the records in random order.\n    queue_capacity: int, size of the read queue. If `None`, it will be set\n      roughly to the size of `x`.\n    num_threads: Integer, number of threads used for reading and enqueueing. In\n      order to have predicted and repeatable order of reading and enqueueing,\n      such as in prediction and evaluation mode, `num_threads` should be 1.\n    target_column: str, name to give the target column `y`. 
This parameter is\n      not used when `y` is a `DataFrame`.\n\n  Returns:\n    Function, that has signature of ()->(dict of `features`, `target`)\n\n  Raises:\n    ValueError: if `x` already contains a column with the same name as `y`, or\n      if the indexes of `x` and `y` don't match.\n    ValueError: if 'shuffle' is not provided or a bool.\n  \"\"\"\n  if not HAS_PANDAS:\n    raise TypeError(\n        'pandas_input_fn should not be called without pandas installed')\n\n  if not isinstance(shuffle, bool):\n    raise ValueError('shuffle must be provided and explicitly set as boolean '\n                     '(it is recommended to set it as True for training); '\n                     'got {}'.format(shuffle))\n\n  if not isinstance(target_column, six.string_types):\n    raise TypeError('target_column must be a string type')\n\n  x = x.copy()\n  if y is not None:\n    if target_column in x:\n      raise ValueError(\n          'Cannot use name %s for target column: DataFrame already has a '\n          'column with that name: %s' % (target_column, x.columns))\n    if not np.array_equal(x.index, y.index):\n      raise ValueError('Index for x and y are mismatched.\\nIndex for x: %s\\n'\n                       'Index for y: %s\\n' % (x.index, y.index))\n    if isinstance(y, pd.DataFrame):\n      y_columns = [\n          (column, _get_unique_target_key(x, column)) for column in list(y)\n      ]\n      target_column = [v for _, v in y_columns]\n      x[target_column] = y\n    else:\n      x[target_column] = y\n\n  # TODO(mdan): These are memory copies. 
We probably don't need 4x slack space.\n  # The sizes below are consistent with what I've seen elsewhere.\n  if queue_capacity is None:\n    if shuffle:\n      queue_capacity = 4 * len(x)\n    else:\n      queue_capacity = len(x)\n  min_after_dequeue = max(queue_capacity / 4, 1)\n\n  def input_fn():\n    \"\"\"Pandas input function.\"\"\"\n    queue = feeding_functions._enqueue_data(  # pylint: disable=protected-access\n        x,\n        queue_capacity,\n        shuffle=shuffle,\n        min_after_dequeue=min_after_dequeue,\n        num_threads=num_threads,\n        enqueue_size=batch_size,\n        num_epochs=num_epochs)\n    if num_epochs is None:\n      features = queue.dequeue_many(batch_size)\n    else:\n      features = queue.dequeue_up_to(batch_size)\n    assert len(features) == len(x.columns) + 1, ('Features should have one '\n                                                 'extra element for the index.')\n    features = features[1:]\n    features = dict(zip(list(x.columns), features))\n    if y is not None:\n      if isinstance(target_column, list):\n        keys = [k for k, _ in y_columns]\n        values = [features.pop(column) for column in target_column]\n        target = {k: v for k, v in zip(keys, values)}\n      else:\n        target = features.pop(target_column)\n      return features, target\n    return features\n\n  return input_fn\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/inputs/pandas_io_test.py",
    "content": "# Copyright 2015 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for pandas_io.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.inputs import pandas_io\n\ntry:\n  # pylint: disable=g-import-not-at-top\n  import pandas as pd\n  HAS_PANDAS = True\nexcept IOError:\n  # Pandas writes a temporary file during import. 
If it fails, don't use pandas.\n  HAS_PANDAS = False\nexcept ImportError:\n  HAS_PANDAS = False\n\n\n@test_util.run_v1_only('Tests v1 only symbols')\nclass PandasIoTest(tf.test.TestCase):\n\n  def makeTestDataFrame(self):\n    index = np.arange(100, 104)\n    a = np.arange(4)\n    b = np.arange(32, 36)\n    x = pd.DataFrame({'a': a, 'b': b}, index=index)\n    y = pd.Series(np.arange(-32, -28), index=index)\n    return x, y\n\n  def makeTestDataFrameWithYAsDataFrame(self):\n    index = np.arange(100, 104)\n    a = np.arange(4)\n    b = np.arange(32, 36)\n    a_label = np.arange(10, 14)\n    b_label = np.arange(50, 54)\n    x = pd.DataFrame({'a': a, 'b': b}, index=index)\n    y = pd.DataFrame({'a_target': a_label, 'b_target': b_label}, index=index)\n    return x, y\n\n  def callInputFnOnce(self, input_fn, session):\n    results = input_fn()\n    coord = tf.train.Coordinator()\n    threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n        session, coord=coord)\n    result_values = session.run(results)\n    coord.request_stop()\n    coord.join(threads)\n    return result_values\n\n  def testPandasInputFn_IndexMismatch(self):\n    if not HAS_PANDAS:\n      return\n    x, _ = self.makeTestDataFrame()\n    y_noindex = pd.Series(np.arange(-32, -28))\n    with self.assertRaises(ValueError):\n      pandas_io.pandas_input_fn(\n          x, y_noindex, batch_size=2, shuffle=False, num_epochs=1)\n\n  def testPandasInputFn_RaisesWhenTargetColumnIsAList(self):\n    if not HAS_PANDAS:\n      return\n\n    x, y = self.makeTestDataFrame()\n\n    with self.assertRaisesRegexp(TypeError,\n                                 'target_column must be a string type'):\n      pandas_io.pandas_input_fn(\n          x,\n          y,\n          batch_size=2,\n          shuffle=False,\n          num_epochs=1,\n          target_column=['one', 'two'])\n\n  def testPandasInputFn_NonBoolShuffle(self):\n    if not HAS_PANDAS:\n      return\n    x, _ = self.makeTestDataFrame()\n    y_noindex 
= pd.Series(np.arange(-32, -28))\n    with self.assertRaisesRegexp(\n        ValueError, 'shuffle must be provided and explicitly '\n        'set as boolean'):\n      # Default shuffle is None\n      pandas_io.pandas_input_fn(x, y_noindex)\n\n  def testPandasInputFn_ProducesExpectedOutputs(self):\n    if not HAS_PANDAS:\n      return\n    with self.cached_session() as session:\n      x, y = self.makeTestDataFrame()\n      input_fn = pandas_io.pandas_input_fn(\n          x, y, batch_size=2, shuffle=False, num_epochs=1)\n\n      features, target = self.callInputFnOnce(input_fn, session)\n\n      self.assertAllEqual(features['a'], [0, 1])\n      self.assertAllEqual(features['b'], [32, 33])\n      self.assertAllEqual(target, [-32, -31])\n\n  def testPandasInputFnWhenYIsDataFrame_ProducesExpectedOutput(self):\n    if not HAS_PANDAS:\n      return\n    with self.cached_session() as session:\n      x, y = self.makeTestDataFrameWithYAsDataFrame()\n      input_fn = pandas_io.pandas_input_fn(\n          x, y, batch_size=2, shuffle=False, num_epochs=1)\n\n      features, targets = self.callInputFnOnce(input_fn, session)\n\n      self.assertAllEqual(features['a'], [0, 1])\n      self.assertAllEqual(features['b'], [32, 33])\n      self.assertAllEqual(targets['a_target'], [10, 11])\n      self.assertAllEqual(targets['b_target'], [50, 51])\n\n  def testPandasInputFnYIsDataFrame_HandlesOverlappingColumns(self):\n    if not HAS_PANDAS:\n      return\n    with self.cached_session() as session:\n      x, y = self.makeTestDataFrameWithYAsDataFrame()\n      y = y.rename(columns={'a_target': 'a', 'b_target': 'b'})\n      input_fn = pandas_io.pandas_input_fn(\n          x, y, batch_size=2, shuffle=False, num_epochs=1)\n\n      features, targets = self.callInputFnOnce(input_fn, session)\n\n      self.assertAllEqual(features['a'], [0, 1])\n      self.assertAllEqual(features['b'], [32, 33])\n      self.assertAllEqual(targets['a'], [10, 11])\n      self.assertAllEqual(targets['b'], [50, 
51])\n\n  def testPandasInputFnYIsDataFrame_HandlesOverlappingColumnsInTargets(self):\n    if not HAS_PANDAS:\n      return\n    with self.cached_session() as session:\n      x, y = self.makeTestDataFrameWithYAsDataFrame()\n      y = y.rename(columns={'a_target': 'a', 'b_target': 'a_n'})\n      input_fn = pandas_io.pandas_input_fn(\n          x, y, batch_size=2, shuffle=False, num_epochs=1)\n\n      features, targets = self.callInputFnOnce(input_fn, session)\n\n      self.assertAllEqual(features['a'], [0, 1])\n      self.assertAllEqual(features['b'], [32, 33])\n      self.assertAllEqual(targets['a'], [10, 11])\n      self.assertAllEqual(targets['a_n'], [50, 51])\n\n  def testPandasInputFn_ProducesOutputsForLargeBatchAndMultipleEpochs(self):\n    if not HAS_PANDAS:\n      return\n    with self.cached_session() as session:\n      index = np.arange(100, 102)\n      a = np.arange(2)\n      b = np.arange(32, 34)\n      x = pd.DataFrame({'a': a, 'b': b}, index=index)\n      y = pd.Series(np.arange(-32, -30), index=index)\n      input_fn = pandas_io.pandas_input_fn(\n          x, y, batch_size=128, shuffle=False, num_epochs=2)\n\n      results = input_fn()\n\n      coord = tf.train.Coordinator()\n      threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n          session, coord=coord)\n\n      features, target = session.run(results)\n      self.assertAllEqual(features['a'], [0, 1, 0, 1])\n      self.assertAllEqual(features['b'], [32, 33, 32, 33])\n      self.assertAllEqual(target, [-32, -31, -32, -31])\n\n      with self.assertRaises(tf.errors.OutOfRangeError):\n        session.run(results)\n\n      coord.request_stop()\n      coord.join(threads)\n\n  def testPandasInputFn_ProducesOutputsWhenDataSizeNotDividedByBatchSize(self):\n    if not HAS_PANDAS:\n      return\n    with self.cached_session() as session:\n      index = np.arange(100, 105)\n      a = np.arange(5)\n      b = np.arange(32, 37)\n      x = pd.DataFrame({'a': a, 'b': b}, index=index)\n      y = 
pd.Series(np.arange(-32, -27), index=index)\n\n      input_fn = pandas_io.pandas_input_fn(\n          x, y, batch_size=2, shuffle=False, num_epochs=1)\n\n      results = input_fn()\n\n      coord = tf.train.Coordinator()\n      threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n          session, coord=coord)\n\n      features, target = session.run(results)\n      self.assertAllEqual(features['a'], [0, 1])\n      self.assertAllEqual(features['b'], [32, 33])\n      self.assertAllEqual(target, [-32, -31])\n\n      features, target = session.run(results)\n      self.assertAllEqual(features['a'], [2, 3])\n      self.assertAllEqual(features['b'], [34, 35])\n      self.assertAllEqual(target, [-30, -29])\n\n      features, target = session.run(results)\n      self.assertAllEqual(features['a'], [4])\n      self.assertAllEqual(features['b'], [36])\n      self.assertAllEqual(target, [-28])\n\n      with self.assertRaises(tf.errors.OutOfRangeError):\n        session.run(results)\n\n      coord.request_stop()\n      coord.join(threads)\n\n  def testPandasInputFn_OnlyX(self):\n    if not HAS_PANDAS:\n      return\n    with self.cached_session() as session:\n      x, _ = self.makeTestDataFrame()\n      input_fn = pandas_io.pandas_input_fn(\n          x, y=None, batch_size=2, shuffle=False, num_epochs=1)\n\n      features = self.callInputFnOnce(input_fn, session)\n\n      self.assertAllEqual(features['a'], [0, 1])\n      self.assertAllEqual(features['b'], [32, 33])\n\n  def testPandasInputFn_ExcludesIndex(self):\n    if not HAS_PANDAS:\n      return\n    with self.cached_session() as session:\n      x, y = self.makeTestDataFrame()\n      input_fn = pandas_io.pandas_input_fn(\n          x, y, batch_size=2, shuffle=False, num_epochs=1)\n\n      features, _ = self.callInputFnOnce(input_fn, session)\n\n      self.assertFalse('index' in features)\n\n  def assertInputsCallableNTimes(self, input_fn, session, n):\n    inputs = input_fn()\n    coord = tf.train.Coordinator()\n 
   threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n        session, coord=coord)\n    for _ in range(n):\n      session.run(inputs)\n    with self.assertRaises(tf.errors.OutOfRangeError):\n      session.run(inputs)\n    coord.request_stop()\n    coord.join(threads)\n\n  def testPandasInputFn_RespectsEpoch_NoShuffle(self):\n    if not HAS_PANDAS:\n      return\n    with self.cached_session() as session:\n      x, y = self.makeTestDataFrame()\n      input_fn = pandas_io.pandas_input_fn(\n          x, y, batch_size=4, shuffle=False, num_epochs=1)\n\n      self.assertInputsCallableNTimes(input_fn, session, 1)\n\n  def testPandasInputFn_RespectsEpoch_WithShuffle(self):\n    if not HAS_PANDAS:\n      return\n    with self.cached_session() as session:\n      x, y = self.makeTestDataFrame()\n      input_fn = pandas_io.pandas_input_fn(\n          x, y, batch_size=4, shuffle=True, num_epochs=1)\n\n      self.assertInputsCallableNTimes(input_fn, session, 1)\n\n  def testPandasInputFn_RespectsEpoch_WithShuffleAutosize(self):\n    if not HAS_PANDAS:\n      return\n    with self.cached_session() as session:\n      x, y = self.makeTestDataFrame()\n      input_fn = pandas_io.pandas_input_fn(\n          x, y, batch_size=2, shuffle=True, queue_capacity=None, num_epochs=2)\n\n      self.assertInputsCallableNTimes(input_fn, session, 4)\n\n  def testPandasInputFn_RespectsEpochUnevenBatches(self):\n    if not HAS_PANDAS:\n      return\n    x, y = self.makeTestDataFrame()\n    with self.cached_session() as session:\n      input_fn = pandas_io.pandas_input_fn(\n          x, y, batch_size=3, shuffle=False, num_epochs=1)\n\n      # Before the last batch, only one element of the epoch should remain.\n      self.assertInputsCallableNTimes(input_fn, session, 2)\n\n  def testPandasInputFn_Idempotent(self):\n    if not HAS_PANDAS:\n      return\n    x, y = self.makeTestDataFrame()\n    for _ in range(2):\n      pandas_io.pandas_input_fn(\n          x, y, batch_size=2, 
shuffle=False, num_epochs=1)()\n    for _ in range(2):\n      pandas_io.pandas_input_fn(\n          x, y, batch_size=2, shuffle=True, num_epochs=1)()\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/inputs/queues/__init__.py",
    "content": ""
  },
  {
    "path": "tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Helper functions for enqueuing data from arrays and pandas `DataFrame`s.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport random\nimport types as tp\nimport numpy as np\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.framework import ops\nfrom tensorflow_estimator.python.estimator.inputs.queues import feeding_queue_runner as fqr\n\ntry:\n  # pylint: disable=g-import-not-at-top\n  import pandas as pd\n  HAS_PANDAS = True\nexcept IOError:\n  # Pandas writes a temporary file during import. 
If it fails, don't use pandas.\n  HAS_PANDAS = False\nexcept ImportError:\n  HAS_PANDAS = False\n\n\ndef _fill_array(arr, seq, fillvalue=0):\n  \"\"\"Recursively fills padded arr with elements from seq.\n\n  If length of seq is less than arr padded length, fillvalue used.\n  Args:\n    arr: Padded tensor of shape [batch_size, ..., max_padded_dim_len].\n    seq: Non-padded list of data samples of shape\n      [batch_size, ..., padded_dim(None)]\n    fillvalue: Default fillvalue to use.\n  \"\"\"\n  if arr.ndim == 1:\n    try:\n      len_ = len(seq)\n    except TypeError:\n      len_ = 0\n    arr[:len_] = seq\n    arr[len_:] = fillvalue\n  else:\n    for subarr, subseq in six.moves.zip_longest(arr, seq, fillvalue=()):\n      _fill_array(subarr, subseq, fillvalue)\n\n\ndef _pad_if_needed(batch_key_item, fillvalue=0):\n  \"\"\" Returns padded batch.\n\n  Args:\n    batch_key_item: List of data samples of any type with shape\n      [batch_size, ..., padded_dim(None)].\n    fillvalue: Default fillvalue to use.\n\n  Returns:\n    Padded with zeros tensor of same type and shape\n      [batch_size, ..., max_padded_dim_len].\n\n  Raises:\n    ValueError if data samples have different shapes (except last padded dim).\n  \"\"\"\n  shapes = [\n      seq.shape[:-1] if len(seq.shape) > 0 else -1 for seq in batch_key_item\n  ]\n  if not all(shapes[0] == x for x in shapes):\n    raise ValueError(\"Array shapes must match.\")\n\n  last_length = [\n      seq.shape[-1] if len(seq.shape) > 0 else 0 for seq in batch_key_item\n  ]\n  if all([x == last_length[0] for x in last_length]):\n    return batch_key_item\n\n  batch_size = len(batch_key_item)\n  max_sequence_length = max(last_length)\n  result_batch = np.zeros(\n      shape=[batch_size] + list(shapes[0]) + [max_sequence_length],\n      dtype=batch_key_item[0].dtype)\n  _fill_array(result_batch, batch_key_item, fillvalue)\n  return result_batch\n\n\ndef _get_integer_indices_for_next_batch(batch_indices_start, batch_size,\n           
                             epoch_end, array_length, current_epoch,\n                                        total_epochs):\n  \"\"\"Returns the integer indices for next batch.\n\n  If total epochs is not None and current epoch is the final epoch, the end\n  index of the next batch should not exceed the `epoch_end` (i.e., the final\n  batch might not have size `batch_size` to avoid overshooting the last epoch).\n\n  Args:\n    batch_indices_start: Integer, the index to start next batch.\n    batch_size: Integer, size of batches to return.\n    epoch_end: Integer, the end index of the epoch. The epoch could start from a\n      random position, so `epoch_end` provides the end index for that.\n    array_length: Integer, the length of the array.\n    current_epoch: Integer, the epoch number has been emitted.\n    total_epochs: Integer or `None`, the total number of epochs to emit. If\n      `None` will run forever.\n\n  Returns:\n    A tuple of a list with integer indices for next batch and `current_epoch`\n    value after the next batch.\n\n  Raises:\n    OutOfRangeError if `current_epoch` is not less than `total_epochs`.\n\n  \"\"\"\n  if total_epochs is not None and current_epoch >= total_epochs:\n    raise tf.errors.OutOfRangeError(\n        None, None, \"Already emitted %s epochs.\" % current_epoch)\n\n  batch_indices_end = batch_indices_start + batch_size\n  batch_indices = [\n      j % array_length for j in range(batch_indices_start, batch_indices_end)\n  ]\n  epoch_end_indices = [i for i, x in enumerate(batch_indices) if x == epoch_end]\n  current_epoch += len(epoch_end_indices)\n\n  if total_epochs is None or current_epoch < total_epochs:\n    return (batch_indices, current_epoch)\n\n  # Now we might have emitted more data for expected epochs. 
Need to trim.\n  final_epoch_end_inclusive = epoch_end_indices[-(current_epoch - total_epochs +\n                                                  1)]\n  batch_indices = batch_indices[:final_epoch_end_inclusive + 1]\n\n  return (batch_indices, total_epochs)\n\n\nclass _ArrayFeedFn(object):\n  \"\"\"Creates feed dictionaries from numpy arrays.\"\"\"\n\n  def __init__(self,\n               placeholders,\n               array,\n               batch_size,\n               random_start=False,\n               seed=None,\n               num_epochs=None):\n    if len(placeholders) != 2:\n      raise ValueError(\"_array_feed_fn expects 2 placeholders; got {}.\".format(\n          len(placeholders)))\n    self._placeholders = placeholders\n    self._array = array\n    self._max = len(array)\n    self._batch_size = batch_size\n    self._num_epochs = num_epochs\n    self._epoch = 0\n    random.seed(seed)\n    self._trav = random.randrange(self._max) if random_start else 0\n    self._epoch_end = (self._trav - 1) % self._max\n\n  def __call__(self):\n    integer_indexes, self._epoch = _get_integer_indices_for_next_batch(\n        batch_indices_start=self._trav,\n        batch_size=self._batch_size,\n        epoch_end=self._epoch_end,\n        array_length=self._max,\n        current_epoch=self._epoch,\n        total_epochs=self._num_epochs)\n\n    self._trav = (integer_indexes[-1] + 1) % self._max\n    return {\n        self._placeholders[0]: integer_indexes,\n        self._placeholders[1]: self._array[integer_indexes]\n    }\n\n\nclass _OrderedDictNumpyFeedFn(object):\n  \"\"\"Creates feed dictionaries from `OrderedDict`s of numpy arrays.\"\"\"\n\n  def __init__(self,\n               placeholders,\n               ordered_dict_of_arrays,\n               batch_size,\n               random_start=False,\n               seed=None,\n               num_epochs=None):\n    if len(placeholders) != len(ordered_dict_of_arrays) + 1:\n      raise ValueError(\"Expected {} placeholders; got 
{}.\".format(\n          len(ordered_dict_of_arrays), len(placeholders)))\n    self._index_placeholder = placeholders[0]\n    self._col_placeholders = placeholders[1:]\n    self._ordered_dict_of_arrays = ordered_dict_of_arrays\n    self._max = len(next(iter(ordered_dict_of_arrays.values())))\n    for _, v in ordered_dict_of_arrays.items():\n      if len(v) != self._max:\n        raise ValueError(\"Array lengths must match.\")\n    self._batch_size = batch_size\n    self._num_epochs = num_epochs\n    self._epoch = 0\n    random.seed(seed)\n    self._trav = random.randrange(self._max) if random_start else 0\n    self._epoch_end = (self._trav - 1) % self._max\n\n  def __call__(self):\n    integer_indexes, self._epoch = _get_integer_indices_for_next_batch(\n        batch_indices_start=self._trav,\n        batch_size=self._batch_size,\n        epoch_end=self._epoch_end,\n        array_length=self._max,\n        current_epoch=self._epoch,\n        total_epochs=self._num_epochs)\n\n    self._trav = (integer_indexes[-1] + 1) % self._max\n    feed_dict = {self._index_placeholder: integer_indexes}\n    cols = [\n        column[integer_indexes]\n        for column in self._ordered_dict_of_arrays.values()\n    ]\n    feed_dict.update(dict(zip(self._col_placeholders, cols)))\n    return feed_dict\n\n\nclass _PandasFeedFn(object):\n  \"\"\"Creates feed dictionaries from pandas `DataFrames`.\"\"\"\n\n  def __init__(self,\n               placeholders,\n               dataframe,\n               batch_size,\n               random_start=False,\n               seed=None,\n               num_epochs=None):\n    if len(placeholders) != len(dataframe.columns) + 1:\n      raise ValueError(\"Expected {} placeholders; got {}.\".format(\n          len(dataframe.columns) + 1, len(placeholders)))\n    self._index_placeholder = placeholders[0]\n    self._col_placeholders = placeholders[1:]\n    self._dataframe = dataframe\n    self._max = len(dataframe)\n    self._batch_size = batch_size\n    
self._num_epochs = num_epochs\n    self._epoch = 0\n    random.seed(seed)\n    self._trav = random.randrange(self._max) if random_start else 0\n    self._epoch_end = (self._trav - 1) % self._max\n\n  def __call__(self):\n    integer_indexes, self._epoch = _get_integer_indices_for_next_batch(\n        batch_indices_start=self._trav,\n        batch_size=self._batch_size,\n        epoch_end=self._epoch_end,\n        array_length=self._max,\n        current_epoch=self._epoch,\n        total_epochs=self._num_epochs)\n\n    self._trav = (integer_indexes[-1] + 1) % self._max\n    result = self._dataframe.iloc[integer_indexes]\n    cols = [result[col].values for col in result.columns]\n    feed_dict = dict(zip(self._col_placeholders, cols))\n    feed_dict[self._index_placeholder] = result.index.values\n    return feed_dict\n\n\nclass _GeneratorFeedFn(object):\n  \"\"\"Creates feed dictionaries from `Generator` of `dicts` of numpy arrays.\"\"\"\n\n  def __init__(self,\n               placeholders,\n               generator,\n               batch_size,\n               random_start=False,\n               seed=None,\n               num_epochs=None,\n               pad_value=None):\n    first_sample = next(generator())\n    if len(placeholders) != len(first_sample):\n      raise ValueError(\"Expected {} placeholders; got {}.\".format(\n          len(first_sample), len(placeholders)))\n    self._keys = sorted(list(first_sample.keys()))\n    self._col_placeholders = placeholders\n    self._generator_function = generator\n    self._iterator = generator()\n    self._batch_size = batch_size\n    self._num_epochs = num_epochs\n    self._epoch = 0\n    self._pad_value = pad_value\n    random.seed(seed)\n\n  def __call__(self):\n    if self._num_epochs and self._epoch >= self._num_epochs:\n      raise tf.errors.OutOfRangeError(\n          None, None, \"Already emitted %s epochs.\" % self._epoch)\n    list_dict = {}\n    list_dict_size = 0\n    while list_dict_size < self._batch_size:\n 
     try:\n        data_row = next(self._iterator)\n      except StopIteration:\n        self._epoch += 1\n        self._iterator = self._generator_function()\n        data_row = next(self._iterator)\n      for index, key in enumerate(self._keys):\n        if key not in data_row.keys():\n          raise KeyError(\"key mismatch between dicts emitted by GenFun \"\n                         \"Expected {} keys; got {}\".format(\n                             self._keys, data_row.keys()))\n        list_dict.setdefault(self._col_placeholders[index],\n                             list()).append(data_row[key])\n        list_dict_size += 1\n\n    if self._pad_value is not None:\n      feed_dict = {\n          key: np.asarray(_pad_if_needed(item, self._pad_value))\n          for key, item in list(list_dict.items())\n      }\n    else:\n      feed_dict = {\n          key: np.asarray(item) for key, item in list(list_dict.items())\n      }\n    return feed_dict\n\n\ndef _enqueue_data(data,\n                  capacity,\n                  shuffle=False,\n                  min_after_dequeue=None,\n                  num_threads=1,\n                  seed=None,\n                  name=\"enqueue_input\",\n                  enqueue_size=1,\n                  num_epochs=None,\n                  pad_value=None):\n  \"\"\"Creates a queue filled from a numpy array or pandas `DataFrame`.\n\n    Returns a queue filled with the rows of the given (`OrderedDict` of) array\n    or `DataFrame`. In the case of a pandas `DataFrame`, the first enqueued\n    `Tensor` corresponds to the index of the `DataFrame`. 
For (`OrderedDict` of)\n    numpy arrays, the first enqueued `Tensor` contains the row number.\n\n  Args:\n    data: a numpy `ndarray`, `OrderedDict` of numpy arrays, or a generator\n      yielding `dict`s of numpy arrays or pandas `DataFrame` that will be read\n      into the queue.\n    capacity: the capacity of the queue.\n    shuffle: whether or not to shuffle the rows of the array.\n    min_after_dequeue: minimum number of elements that can remain in the queue\n      after a dequeue operation. Only used when `shuffle` is true. If not set,\n      defaults to `capacity` / 4.\n    num_threads: number of threads used for reading and enqueueing.\n    seed: used to seed shuffling and reader starting points.\n    name: a scope name identifying the data.\n    enqueue_size: the number of rows to enqueue per step.\n    num_epochs: limit enqueuing to a specified number of epochs, if provided.\n    pad_value: default value for dynamic padding of data samples, if provided.\n\n  Returns:\n    A queue filled with the rows of the given (`OrderedDict` of) array or\n      `DataFrame`.\n\n  Raises:\n    TypeError: `data` is not a Pandas `DataFrame`, an `OrderedDict` of numpy\n      arrays, a numpy `ndarray`, or a generator producing these.\n    NotImplementedError: padding and shuffling data at the same time.\n    NotImplementedError: padding usage with non generator data type.\n  \"\"\"\n  with ops.name_scope(name):\n    if isinstance(data, np.ndarray):\n      types = [tf.dtypes.int64, tf.dtypes.as_dtype(data.dtype)]\n      queue_shapes = [(), data.shape[1:]]\n      get_feed_fn = _ArrayFeedFn\n    elif isinstance(data, collections.OrderedDict):\n      types = [tf.dtypes.int64\n              ] + [tf.dtypes.as_dtype(col.dtype) for col in data.values()]\n      queue_shapes = [()] + [col.shape[1:] for col in data.values()]\n      get_feed_fn = _OrderedDictNumpyFeedFn\n    elif isinstance(data, tp.FunctionType):\n      x_first_el = six.next(data())\n      x_first_keys = 
sorted(x_first_el.keys())\n      x_first_values = [x_first_el[key] for key in x_first_keys]\n      types = [tf.dtypes.as_dtype(col.dtype) for col in x_first_values]\n      queue_shapes = [col.shape for col in x_first_values]\n      get_feed_fn = _GeneratorFeedFn\n    elif HAS_PANDAS and isinstance(data, pd.DataFrame):\n      types = [\n          tf.dtypes.as_dtype(dt)\n          for dt in [data.index.dtype] + list(data.dtypes)\n      ]\n      queue_shapes = [() for _ in types]\n      get_feed_fn = _PandasFeedFn\n    else:\n      raise TypeError(\n          \"data must be either a numpy array or pandas DataFrame if pandas is \"\n          \"installed; got {}\".format(type(data).__name__))\n\n    pad_data = pad_value is not None\n    if pad_data and get_feed_fn is not _GeneratorFeedFn:\n      raise NotImplementedError(\n          \"padding is only available with generator usage\")\n    if shuffle and pad_data:\n      raise NotImplementedError(\n          \"padding and shuffling data at the same time is not implemented\")\n\n    # TODO(jamieas): TensorBoard warnings for all warnings below once available.\n\n    if num_threads > 1 and num_epochs is not None:\n      tf.compat.v1.logging.warn(\n          \"enqueue_data was called with num_epochs and num_threads > 1. \"\n          \"num_epochs is applied per thread, so this will produce more \"\n          \"epochs than you probably intend. \"\n          \"If you want to limit epochs, use one thread.\")\n\n    if shuffle and num_threads > 1 and num_epochs is not None:\n      tf.compat.v1.logging.warn(\n          \"enqueue_data was called with shuffle=True, num_threads > 1, and \"\n          \"num_epochs. 
This will create multiple threads, all reading the \"\n          \"array/dataframe in order adding to the same shuffling queue; the \"\n          \"results will likely not be sufficiently shuffled.\")\n\n    if not shuffle and num_threads > 1:\n      tf.compat.v1.logging.warn(\n          \"enqueue_data was called with shuffle=False and num_threads > 1. \"\n          \"This will create multiple threads, all reading the \"\n          \"array/dataframe in order. If you want examples read in order, use\"\n          \" one thread; if you want multiple threads, enable shuffling.\")\n\n    if shuffle:\n      min_after_dequeue = int(\n          capacity / 4 if min_after_dequeue is None else min_after_dequeue)\n      queue = tf.queue.RandomShuffleQueue(\n          capacity,\n          min_after_dequeue,\n          dtypes=types,\n          shapes=queue_shapes,\n          seed=seed)\n    elif pad_data:\n      min_after_dequeue = 0  # just for the summary text\n      queue_shapes = list(\n          map(lambda x: tuple(list(x[:-1]) + [None])\n              if len(x) > 0 else x, queue_shapes))\n      queue = tf.queue.PaddingFIFOQueue(\n          capacity, dtypes=types, shapes=queue_shapes)\n    else:\n      min_after_dequeue = 0  # just for the summary text\n      queue = tf.queue.FIFOQueue(capacity, dtypes=types, shapes=queue_shapes)\n\n    enqueue_ops = []\n    feed_fns = []\n\n    for i in range(num_threads):\n      # Note the placeholders have no shapes, so they will accept any\n      # enqueue_size.  
enqueue_many below will break them up.\n      placeholders = [tf.compat.v1.placeholder(t) for t in types]\n\n      enqueue_ops.append(queue.enqueue_many(placeholders))\n      seed_i = None if seed is None else (i + 1) * seed\n\n      if not pad_data:\n        feed_fns.append(\n            get_feed_fn(\n                placeholders,\n                data,\n                enqueue_size,\n                random_start=shuffle,\n                seed=seed_i,\n                num_epochs=num_epochs))\n      else:\n        feed_fns.append(\n            get_feed_fn(\n                placeholders,\n                data,\n                enqueue_size,\n                random_start=shuffle,\n                seed=seed_i,\n                num_epochs=num_epochs,\n                pad_value=pad_value))\n\n    runner = fqr._FeedingQueueRunner(  # pylint: disable=protected-access\n        queue=queue,\n        enqueue_ops=enqueue_ops,\n        feed_fns=feed_fns)\n    tf.compat.v1.train.queue_runner.add_queue_runner(runner)\n\n    full = (\n        tf.cast(\n            tf.math.maximum(0,\n                            queue.size() - min_after_dequeue),\n            tf.dtypes.float32) * (1. / (capacity - min_after_dequeue)))\n    # Note that name contains a '/' at the end so we intentionally do not place\n    # a '/' after %s below.\n    summary_name = (\n        \"queue/%sfraction_over_%d_of_%d_full\" %\n        (queue.name, min_after_dequeue, capacity - min_after_dequeue))\n    tf.compat.v1.summary.scalar(summary_name, full)\n    return queue\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/inputs/queues/feeding_functions_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests feeding functions using arrays and `DataFrames`.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.inputs.queues import feeding_functions as ff\n\ntry:\n  # pylint: disable=g-import-not-at-top\n  import pandas as pd\n  HAS_PANDAS = True\nexcept IOError:\n  # Pandas writes a temporary file during import. 
If it fails, don't use pandas.\n  HAS_PANDAS = False\nexcept ImportError:\n  HAS_PANDAS = False\n\n\ndef vals_to_list(a):\n  return {\n      key: val.tolist() if isinstance(val, np.ndarray) else val\n      for key, val in a.items()\n  }\n\n\nclass _FeedingFunctionsTestCase(tf.test.TestCase):\n  \"\"\"Tests for feeding functions.\"\"\"\n\n  def testArrayFeedFnBatchOne(self):\n    array = np.arange(32).reshape([16, 2])\n    placeholders = [\"index_placeholder\", \"value_placeholder\"]\n    aff = ff._ArrayFeedFn(placeholders, array, 1)\n\n    # cycle around a couple times\n    for x in range(0, 100):\n      i = x % 16\n      expected = {\n          \"index_placeholder\": [i],\n          \"value_placeholder\": [[2 * i, 2 * i + 1]]\n      }\n      actual = aff()\n      self.assertEqual(expected, vals_to_list(actual))\n\n  def testArrayFeedFnBatchFive(self):\n    array = np.arange(32).reshape([16, 2])\n    placeholders = [\"index_placeholder\", \"value_placeholder\"]\n    aff = ff._ArrayFeedFn(placeholders, array, 5)\n\n    # cycle around a couple times\n    for _ in range(0, 101, 2):\n      aff()\n\n    expected = {\n        \"index_placeholder\": [15, 0, 1, 2, 3],\n        \"value_placeholder\": [[30, 31], [0, 1], [2, 3], [4, 5], [6, 7]]\n    }\n    actual = aff()\n    self.assertEqual(expected, vals_to_list(actual))\n\n  def testArrayFeedFnBatchTwoWithOneEpoch(self):\n    array = np.arange(5) + 10\n    placeholders = [\"index_placeholder\", \"value_placeholder\"]\n    aff = ff._ArrayFeedFn(placeholders, array, batch_size=2, num_epochs=1)\n\n    expected = {\"index_placeholder\": [0, 1], \"value_placeholder\": [10, 11]}\n    actual = aff()\n    self.assertEqual(expected, vals_to_list(actual))\n\n    expected = {\"index_placeholder\": [2, 3], \"value_placeholder\": [12, 13]}\n    actual = aff()\n    self.assertEqual(expected, vals_to_list(actual))\n\n    expected = {\"index_placeholder\": [4], \"value_placeholder\": [14]}\n    actual = aff()\n    
self.assertEqual(expected, vals_to_list(actual))\n\n  def testArrayFeedFnBatchOneHundred(self):\n    array = np.arange(32).reshape([16, 2])\n    placeholders = [\"index_placeholder\", \"value_placeholder\"]\n    aff = ff._ArrayFeedFn(placeholders, array, 100)\n\n    expected = {\n        \"index_placeholder\":\n            list(range(0, 16)) * 6 + list(range(0, 4)),\n        \"value_placeholder\":\n            np.arange(32).reshape([16, 2]).tolist() * 6 +\n            [[0, 1], [2, 3], [4, 5], [6, 7]]\n    }\n    actual = aff()\n    self.assertEqual(expected, vals_to_list(actual))\n\n  def testArrayFeedFnBatchOneHundredWithSmallerArrayAndMultipleEpochs(self):\n    array = np.arange(2) + 10\n    placeholders = [\"index_placeholder\", \"value_placeholder\"]\n    aff = ff._ArrayFeedFn(placeholders, array, batch_size=100, num_epochs=2)\n\n    expected = {\n        \"index_placeholder\": [0, 1, 0, 1],\n        \"value_placeholder\": [10, 11, 10, 11],\n    }\n    actual = aff()\n    self.assertEqual(expected, vals_to_list(actual))\n\n  def testPandasFeedFnBatchOne(self):\n    if not HAS_PANDAS:\n      return\n    array1 = np.arange(32, 64)\n    array2 = np.arange(64, 96)\n    df = pd.DataFrame({\"a\": array1, \"b\": array2}, index=np.arange(96, 128))\n    placeholders = [\"index_placeholder\", \"a_placeholder\", \"b_placeholder\"]\n    aff = ff._PandasFeedFn(placeholders, df, 1)\n\n    # cycle around a couple times\n    for x in range(0, 100):\n      i = x % 32\n      expected = {\n          \"index_placeholder\": [i + 96],\n          \"a_placeholder\": [32 + i],\n          \"b_placeholder\": [64 + i]\n      }\n      actual = aff()\n      self.assertEqual(expected, vals_to_list(actual))\n\n  def testPandasFeedFnBatchFive(self):\n    if not HAS_PANDAS:\n      return\n    array1 = np.arange(32, 64)\n    array2 = np.arange(64, 96)\n    df = pd.DataFrame({\"a\": array1, \"b\": array2}, index=np.arange(96, 128))\n    placeholders = [\"index_placeholder\", \"a_placeholder\", 
\"b_placeholder\"]\n    aff = ff._PandasFeedFn(placeholders, df, 5)\n\n    # cycle around a couple times\n    for _ in range(0, 101, 2):\n      aff()\n\n    expected = {\n        \"index_placeholder\": [127, 96, 97, 98, 99],\n        \"a_placeholder\": [63, 32, 33, 34, 35],\n        \"b_placeholder\": [95, 64, 65, 66, 67]\n    }\n    actual = aff()\n    self.assertEqual(expected, vals_to_list(actual))\n\n  def testPandasFeedFnBatchTwoWithOneEpoch(self):\n    if not HAS_PANDAS:\n      return\n    array1 = np.arange(32, 37)\n    array2 = np.arange(64, 69)\n    df = pd.DataFrame({\"a\": array1, \"b\": array2}, index=np.arange(96, 101))\n    placeholders = [\"index_placeholder\", \"a_placeholder\", \"b_placeholder\"]\n    aff = ff._PandasFeedFn(placeholders, df, batch_size=2, num_epochs=1)\n\n    expected = {\n        \"index_placeholder\": [96, 97],\n        \"a_placeholder\": [32, 33],\n        \"b_placeholder\": [64, 65]\n    }\n    actual = aff()\n    self.assertEqual(expected, vals_to_list(actual))\n\n    expected = {\n        \"index_placeholder\": [98, 99],\n        \"a_placeholder\": [34, 35],\n        \"b_placeholder\": [66, 67]\n    }\n    actual = aff()\n    self.assertEqual(expected, vals_to_list(actual))\n\n    expected = {\n        \"index_placeholder\": [100],\n        \"a_placeholder\": [36],\n        \"b_placeholder\": [68]\n    }\n    actual = aff()\n    self.assertEqual(expected, vals_to_list(actual))\n\n  def testPandasFeedFnBatchOneHundred(self):\n    if not HAS_PANDAS:\n      return\n    array1 = np.arange(32, 64)\n    array2 = np.arange(64, 96)\n    df = pd.DataFrame({\"a\": array1, \"b\": array2}, index=np.arange(96, 128))\n    placeholders = [\"index_placeholder\", \"a_placeholder\", \"b_placeholder\"]\n    aff = ff._PandasFeedFn(placeholders, df, 100)\n\n    expected = {\n        \"index_placeholder\": list(range(96, 128)) * 3 + list(range(96, 100)),\n        \"a_placeholder\": list(range(32, 64)) * 3 + list(range(32, 36)),\n        
\"b_placeholder\": list(range(64, 96)) * 3 + list(range(64, 68))\n    }\n    actual = aff()\n    self.assertEqual(expected, vals_to_list(actual))\n\n  def testPandasFeedFnBatchOneHundredWithSmallDataArrayAndMultipleEpochs(self):\n    if not HAS_PANDAS:\n      return\n    array1 = np.arange(32, 34)\n    array2 = np.arange(64, 66)\n    df = pd.DataFrame({\"a\": array1, \"b\": array2}, index=np.arange(96, 98))\n    placeholders = [\"index_placeholder\", \"a_placeholder\", \"b_placeholder\"]\n    aff = ff._PandasFeedFn(placeholders, df, batch_size=100, num_epochs=2)\n\n    expected = {\n        \"index_placeholder\": [96, 97, 96, 97],\n        \"a_placeholder\": [32, 33, 32, 33],\n        \"b_placeholder\": [64, 65, 64, 65]\n    }\n    actual = aff()\n    self.assertEqual(expected, vals_to_list(actual))\n\n  def testOrderedDictNumpyFeedFnBatchTwoWithOneEpoch(self):\n    a = np.arange(32, 37)\n    b = np.arange(64, 69)\n    x = {\"a\": a, \"b\": b}\n    ordered_dict_x = collections.OrderedDict(\n        sorted(x.items(), key=lambda t: t[0]))\n    placeholders = [\"index_placeholder\", \"a_placeholder\", \"b_placeholder\"]\n    aff = ff._OrderedDictNumpyFeedFn(\n        placeholders, ordered_dict_x, batch_size=2, num_epochs=1)\n\n    expected = {\n        \"index_placeholder\": [0, 1],\n        \"a_placeholder\": [32, 33],\n        \"b_placeholder\": [64, 65]\n    }\n    actual = aff()\n    self.assertEqual(expected, vals_to_list(actual))\n\n    expected = {\n        \"index_placeholder\": [2, 3],\n        \"a_placeholder\": [34, 35],\n        \"b_placeholder\": [66, 67]\n    }\n    actual = aff()\n    self.assertEqual(expected, vals_to_list(actual))\n\n    expected = {\n        \"index_placeholder\": [4],\n        \"a_placeholder\": [36],\n        \"b_placeholder\": [68]\n    }\n    actual = aff()\n    self.assertEqual(expected, vals_to_list(actual))\n\n  def testOrderedDictNumpyFeedFnLargeBatchWithSmallArrayAndMultipleEpochs(self):\n    a = np.arange(32, 34)\n    b = 
np.arange(64, 66)\n    x = {\"a\": a, \"b\": b}\n    ordered_dict_x = collections.OrderedDict(\n        sorted(x.items(), key=lambda t: t[0]))\n    placeholders = [\"index_placeholder\", \"a_placeholder\", \"b_placeholder\"]\n    aff = ff._OrderedDictNumpyFeedFn(\n        placeholders, ordered_dict_x, batch_size=100, num_epochs=2)\n\n    expected = {\n        \"index_placeholder\": [0, 1, 0, 1],\n        \"a_placeholder\": [32, 33, 32, 33],\n        \"b_placeholder\": [64, 65, 64, 65]\n    }\n    actual = aff()\n    self.assertEqual(expected, vals_to_list(actual))\n\n  def testFillArraySmall(self):\n    a = (\n        np.ones(shape=[32, 32], dtype=np.int32).tolist() +\n        np.ones(shape=[32, 36], dtype=np.int32).tolist())\n    actual = np.ones(shape=[64, 36], dtype=np.int32)\n    ff._fill_array(actual, a)\n    expected = np.ones(shape=[64, 36], dtype=np.int32)\n    expected[:32, 32:] = 0\n    self.assertEqual(expected.tolist(), actual.tolist())\n\n  def testFillArrayLarge(self):\n    a = (\n        np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() +\n        np.ones(shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())\n    actual = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)\n    ff._fill_array(actual, a)\n    expected = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)\n    expected[:8, ..., 32:] = 0\n    self.assertEqual(expected.tolist(), actual.tolist())\n\n  def testFillArraySmallWithSpecifiedValue(self):\n    fill_value = 8\n    a = (\n        np.ones(shape=[32, 32], dtype=np.int32).tolist() +\n        np.ones(shape=[32, 36], dtype=np.int32).tolist())\n    actual = np.ones(shape=[64, 36], dtype=np.int32)\n    ff._fill_array(actual, a, fill_value)\n    expected = np.ones(shape=[64, 36], dtype=np.int32)\n    expected[:32, 32:] = fill_value\n    self.assertEqual(expected.tolist(), actual.tolist())\n\n  def testFillArrayLargeWithSpecifiedValue(self):\n    fill_value = 8\n    a = (\n        np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() +\n   
     np.ones(shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())\n    actual = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)\n    ff._fill_array(actual, a, fill_value)\n    expected = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)\n    expected[:8, ..., 32:] = fill_value\n    self.assertEqual(expected.tolist(), actual.tolist())\n\n  def testPadIfNeededSmall(self):\n    a = (\n        np.ones(shape=[32, 32], dtype=np.int32).tolist() +\n        np.ones(shape=[32, 36], dtype=np.int32).tolist())\n    a = list(map(np.array, a))\n    actual = ff._pad_if_needed(a)\n    expected = np.ones(shape=[64, 36], dtype=np.int32)\n    expected[:32, 32:] = 0\n    self.assertEqual(expected.tolist(), actual.tolist())\n\n  def testPadIfNeededLarge(self):\n    a = (\n        np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() +\n        np.ones(shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())\n    a = list(map(np.array, a))\n    actual = ff._pad_if_needed(a)\n    expected = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)\n    expected[:8, ..., 32:] = 0\n    self.assertEqual(expected.tolist(), actual.tolist())\n\n  def testPadIfNeededSmallWithSpecifiedValue(self):\n    fill_value = 8\n    a = (\n        np.ones(shape=[32, 32], dtype=np.int32).tolist() +\n        np.ones(shape=[32, 36], dtype=np.int32).tolist())\n    a = list(map(np.array, a))\n    actual = ff._pad_if_needed(a, fill_value)\n    expected = np.ones(shape=[64, 36], dtype=np.int32)\n    expected[:32, 32:] = fill_value\n    self.assertEqual(expected.tolist(), actual.tolist())\n\n  def testPadIfNeededLargeWithSpecifiedValue(self):\n    fill_value = 8\n    a = (\n        np.ones(shape=[8, 8, 8, 8, 32], dtype=np.int32).tolist() +\n        np.ones(shape=[8, 8, 8, 8, 36], dtype=np.int32).tolist())\n    a = list(map(np.array, a))\n    actual = ff._pad_if_needed(a, fill_value)\n    expected = np.ones(shape=[16, 8, 8, 8, 36], dtype=np.int32)\n    expected[:8, ..., 32:] = fill_value\n    self.assertEqual(expected.tolist(), 
actual.tolist())\n\n  def testPadIfNeededSmallWithSpecifiedNonNumericValue(self):\n    fill_value = False\n    a = (\n        np.ones(shape=[32, 32], dtype=bool).tolist() +\n        np.ones(shape=[32, 36], dtype=bool).tolist())\n    a = list(map(np.array, a))\n    actual = ff._pad_if_needed(a, fill_value)\n    expected = np.ones(shape=[64, 36], dtype=bool)\n    expected[:32, 32:] = fill_value\n    self.assertEqual(expected.tolist(), actual.tolist())\n\n  def testPadIfNeededLargeWithSpecifiedNonNumericValue(self):\n    fill_value = False\n    a = (\n        np.ones(shape=[8, 8, 8, 8, 32], dtype=bool).tolist() +\n        np.ones(shape=[8, 8, 8, 8, 36], dtype=bool).tolist())\n    a = list(map(np.array, a))\n    actual = ff._pad_if_needed(a, fill_value)\n    expected = np.ones(shape=[16, 8, 8, 8, 36], dtype=bool)\n    expected[:8, ..., 32:] = fill_value\n    self.assertEqual(expected.tolist(), actual.tolist())\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"A `QueueRunner` that takes a feed function as an argument.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport threading\nimport tensorflow as tf\n\n\nclass _FeedingQueueRunner(tf.compat.v1.train.queue_runner.QueueRunner):\n  \"\"\"A queue runner that allows the feeding of values such as numpy arrays.\"\"\"\n\n  def __init__(self,\n               queue=None,\n               enqueue_ops=None,\n               close_op=None,\n               cancel_op=None,\n               feed_fns=None,\n               queue_closed_exception_types=None):\n    \"\"\"Initialize the queue runner.\n\n    For further documentation, see `queue_runner.py`. Note that\n    `FeedingQueueRunner` does not support construction from protobuffer nor\n    serialization to protobuffer.\n\n    Args:\n      queue: A `Queue`.\n      enqueue_ops: List of enqueue ops to run in threads later.\n      close_op: Op to close the queue. Pending enqueue ops are preserved.\n      cancel_op: Op to close the queue and cancel pending enqueue ops.\n      feed_fns: a list of functions that return a dictionary mapping fed\n        `Tensor`s to values. 
Must be the same length as `enqueue_ops`.\n      queue_closed_exception_types: Optional tuple of Exception types that\n        indicate that the queue has been closed when raised during an enqueue\n        operation.  Defaults to `(tf.errors.OutOfRangeError,\n        tf.errors.CancelledError)`.\n\n    Raises:\n      ValueError: `feed_fns` is not `None` and has different length than\n        `enqueue_ops`.\n    \"\"\"\n    if queue_closed_exception_types is None:\n      queue_closed_exception_types = (tf.errors.OutOfRangeError,\n                                      tf.errors.CancelledError)\n    super(_FeedingQueueRunner, self).__init__(\n        queue,\n        enqueue_ops,\n        close_op,\n        cancel_op,\n        queue_closed_exception_types=queue_closed_exception_types)\n    if feed_fns is None:\n      self._feed_fns = [None for _ in enqueue_ops]\n    else:\n      if len(feed_fns) != len(enqueue_ops):\n        raise ValueError(\n            \"If feed_fns is not None, it must have the same length as \"\n            \"enqueue_ops.\")\n      self._feed_fns = feed_fns\n\n  # pylint: disable=broad-except\n  def _run(self, sess, enqueue_op, feed_fn, coord=None):\n    \"\"\"Execute the enqueue op in a loop, close the queue in case of error.\n\n    Args:\n      sess: A `Session`.\n      enqueue_op: The `Operation` to run.\n      feed_fn: the feed function to pass to `sess.run`.\n      coord: Optional `Coordinator` object for reporting errors and checking for\n        stop conditions.\n    \"\"\"\n    # TODO(jamieas): Reduce code duplication with `QueueRunner`.\n    if coord:\n      coord.register_thread(threading.current_thread())\n    decremented = False\n    try:\n      while True:\n        if coord and coord.should_stop():\n          break\n        try:\n          feed_dict = None if feed_fn is None else feed_fn()\n          sess.run(enqueue_op, feed_dict=feed_dict)\n        except (tf.errors.OutOfRangeError, tf.errors.CancelledError):\n          # This 
exception indicates that a queue was closed.\n          with self._lock:\n            self._runs_per_session[sess] -= 1\n            decremented = True\n            if self._runs_per_session[sess] == 0:\n              try:\n                sess.run(self._close_op)\n              except Exception as e:\n                # Intentionally ignore errors from close_op.\n                tf.compat.v1.logging.vlog(1, \"Ignored exception: %s\", str(e))\n            return\n    except Exception as e:\n      # This catches all other exceptions.\n      if coord:\n        coord.request_stop(e)\n      else:\n        tf.compat.v1.logging.error(\"Exception in QueueRunner: %s\", str(e))\n        with self._lock:\n          self._exceptions_raised.append(e)\n        raise\n    finally:\n      # Make sure we account for all terminations: normal or errors.\n      if not decremented:\n        with self._lock:\n          self._runs_per_session[sess] -= 1\n\n  def create_threads(self, sess, coord=None, daemon=False, start=False):\n    \"\"\"Create threads to run the enqueue ops for the given session.\n\n    This method requires a session in which the graph was launched.  It creates\n    a list of threads, optionally starting them.  There is one thread for each\n    op passed in `enqueue_ops`.\n\n    The `coord` argument is an optional coordinator, that the threads will use\n    to terminate together and report exceptions.  If a coordinator is given,\n    this method starts an additional thread to close the queue when the\n    coordinator requests a stop.\n\n    If previously created threads for the given session are still running, no\n    new threads will be created.\n\n    Args:\n      sess: A `Session`.\n      coord: Optional `Coordinator` object for reporting errors and checking\n        stop conditions.\n      daemon: Boolean.  If `True` make the threads daemon threads.\n      start: Boolean.  If `True` starts the threads.  
If `False` the caller must\n        call the `start()` method of the returned threads.\n\n    Returns:\n      A list of threads.\n    \"\"\"\n    with self._lock:\n      try:\n        if self._runs_per_session[sess] > 0:\n          # Already started: no new threads to return.\n          return []\n      except KeyError:\n        # We haven't seen this session yet.\n        pass\n      self._runs_per_session[sess] = len(self._enqueue_ops)\n      self._exceptions_raised = []\n\n    ret_threads = [\n        threading.Thread(target=self._run, args=(sess, op, feed_fn, coord))\n        for op, feed_fn in zip(self._enqueue_ops, self._feed_fns)\n    ]\n    if coord:\n      ret_threads.append(\n          threading.Thread(\n              target=self._close_on_stop, args=(sess, self._cancel_op, coord)))\n    for t in ret_threads:\n      if daemon:\n        t.daemon = True\n      if start:\n        t.start()\n    return ret_threads\n\n  def _init_from_proto(self, queue_runner_def):\n    raise NotImplementedError(\n        \"{} does not support initialization from proto.\".format(\n            type(self).__name__))\n\n  def to_proto(self):\n    raise NotImplementedError(\n        \"{} does not support serialization to proto.\".format(\n            type(self).__name__))\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests `FeedingQueueRunner` using arrays and `DataFrames`.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.inputs.queues import feeding_functions as ff\n\ntry:\n  # pylint: disable=g-import-not-at-top\n  import pandas as pd\n  HAS_PANDAS = True\nexcept IOError:\n  # Pandas writes a temporary file during import. 
If it fails, don't use pandas.\n  HAS_PANDAS = False\nexcept ImportError:\n  HAS_PANDAS = False\n\n\ndef get_rows(array, row_indices):\n  rows = [array[i] for i in row_indices]\n  return np.vstack(rows)\n\n\n@test_util.deprecated_graph_mode_only\nclass FeedingQueueRunnerTestCase(tf.test.TestCase):\n  \"\"\"Tests for `FeedingQueueRunner`.\"\"\"\n\n  def testArrayFeeding(self):\n    with tf.Graph().as_default():\n      array = np.arange(32).reshape([16, 2])\n      q = ff._enqueue_data(array, capacity=100)\n      batch_size = 3\n      dq_op = q.dequeue_many(batch_size)\n      with tf.compat.v1.Session() as sess:\n        coord = tf.train.Coordinator()\n        threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n            sess=sess, coord=coord)\n        for i in range(100):\n          indices = [\n              j % array.shape[0]\n              for j in range(batch_size * i, batch_size * (i + 1))\n          ]\n          expected_dq = get_rows(array, indices)\n          dq = sess.run(dq_op)\n          np.testing.assert_array_equal(indices, dq[0])\n          np.testing.assert_array_equal(expected_dq, dq[1])\n        coord.request_stop()\n        coord.join(threads)\n\n  def testArrayFeedingMultiThread(self):\n    with tf.Graph().as_default():\n      array = np.arange(256).reshape([128, 2])\n      q = ff._enqueue_data(array, capacity=128, num_threads=8, shuffle=True)\n      batch_size = 3\n      dq_op = q.dequeue_many(batch_size)\n      with tf.compat.v1.Session() as sess:\n        coord = tf.train.Coordinator()\n        threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n            sess=sess, coord=coord)\n        for _ in range(100):\n          dq = sess.run(dq_op)\n          indices = dq[0]\n          expected_dq = get_rows(array, indices)\n          np.testing.assert_array_equal(expected_dq, dq[1])\n        coord.request_stop()\n        coord.join(threads)\n\n  def testPandasFeeding(self):\n    if not HAS_PANDAS:\n      return\n    with 
tf.Graph().as_default():\n      array1 = np.arange(32)\n      array2 = np.arange(32, 64)\n      df = pd.DataFrame({\"a\": array1, \"b\": array2}, index=np.arange(64, 96))\n      q = ff._enqueue_data(df, capacity=100)\n      batch_size = 5\n      dq_op = q.dequeue_many(5)\n      with tf.compat.v1.Session() as sess:\n        coord = tf.train.Coordinator()\n        threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n            sess=sess, coord=coord)\n        for i in range(100):\n          indices = [\n              j % array1.shape[0]\n              for j in range(batch_size * i, batch_size * (i + 1))\n          ]\n          expected_df_indices = df.index[indices]\n          expected_rows = df.iloc[indices]\n          dq = sess.run(dq_op)\n          np.testing.assert_array_equal(expected_df_indices, dq[0])\n          for col_num, col in enumerate(df.columns):\n            np.testing.assert_array_equal(expected_rows[col].values,\n                                          dq[col_num + 1])\n        coord.request_stop()\n        coord.join(threads)\n\n  def testPandasFeedingMultiThread(self):\n    if not HAS_PANDAS:\n      return\n    with tf.Graph().as_default():\n      array1 = np.arange(128, 256)\n      array2 = 2 * array1\n      df = pd.DataFrame({\"a\": array1, \"b\": array2}, index=np.arange(128))\n      q = ff._enqueue_data(df, capacity=128, num_threads=8, shuffle=True)\n      batch_size = 5\n      dq_op = q.dequeue_many(batch_size)\n      with tf.compat.v1.Session() as sess:\n        coord = tf.train.Coordinator()\n        threads = tf.compat.v1.train.queue_runner.start_queue_runners(\n            sess=sess, coord=coord)\n        for _ in range(100):\n          dq = sess.run(dq_op)\n          indices = dq[0]\n          expected_rows = df.iloc[indices]\n          for col_num, col in enumerate(df.columns):\n            np.testing.assert_array_equal(expected_rows[col].values,\n                                          dq[col_num + 1])\n        
coord.request_stop()\n        coord.join(threads)\n\n\nif __name__ == \"__main__\":\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/keras_distribute_strategy_test.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for Keras model-to-estimator using tf.distribute.Strategy.\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.distribute import strategy_combinations\nfrom tensorflow.python.eager import test\nfrom tensorflow.python.ops.parsing_ops import gen_parsing_ops\nfrom tensorflow_estimator.python.estimator import keras_lib\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator import run_config as run_config_lib\n\n_RANDOM_SEED = 1337\n_TRAIN_SIZE = 200\n_INPUT_SIZE = (10,)\n_NUM_CLASS = 2\n\n\ndef simple_sequential_model():\n  model = tf_keras.models.Sequential()\n  model.add(tf_keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE))\n  model.add(tf_keras.layers.Dropout(0.1))\n  model.add(tf_keras.layers.Dense(_NUM_CLASS, activation='softmax'))\n  return model\n\n\ndef simple_functional_model():\n  a = tf_keras.layers.Input(shape=_INPUT_SIZE)\n  b = tf_keras.layers.Dense(16, activation='relu')(a)\n  b = tf_keras.layers.Dropout(0.1)(b)\n  b = tf_keras.layers.Dense(_NUM_CLASS, activation='softmax')(b)\n  
model = tf_keras.models.Model(inputs=[a], outputs=[b])\n  return model\n\n\ndef multi_inputs_multi_outputs_model():\n  input_a = tf_keras.layers.Input(shape=(16,), name='input_a')\n  input_b = tf_keras.layers.Input(shape=(16,), name='input_b')\n  input_m = tf_keras.layers.Input(shape=(8,), dtype='string', name='input_m')\n  dense = tf_keras.layers.Dense(8, name='dense_1')\n\n  interm_a = dense(input_a)\n  # Read m\n  interm_m = tf_keras.layers.Lambda(gen_parsing_ops.string_to_number)(input_m)\n  interm_s = tf_keras.layers.Lambda(lambda k: k[0] * k[1])([interm_m, interm_a])\n  interm_b = dense(input_b)\n  merged = tf_keras.layers.concatenate([interm_s, interm_b], name='merge')\n  output_c = tf_keras.layers.Dense(3, activation='softmax', name='dense_2')(\n      merged)\n  output_d = tf_keras.layers.Dense(2, activation='softmax', name='dense_3')(\n      merged)\n  model = tf_keras.models.Model(\n      inputs=[input_a, input_b, input_m], outputs=[output_c, output_d])\n  model.compile(\n      loss='categorical_crossentropy',\n      optimizer=tf_keras.optimizers.legacy.SGD(learning_rate=0.001),\n      metrics={\n          'dense_2': 'categorical_accuracy',\n          'dense_3': 'categorical_accuracy'\n      })\n  return model\n\n\ndef get_ds_train_input_fn():\n  np.random.seed(_RANDOM_SEED)\n  (x_train, y_train), _ = get_test_data(\n      train_samples=_TRAIN_SIZE,\n      test_samples=50,\n      input_shape=_INPUT_SIZE,\n      num_classes=_NUM_CLASS)\n  y_train = tf_keras.utils.to_categorical(y_train)\n\n  dataset = tf.compat.v1.data.Dataset.from_tensor_slices((x_train, y_train))\n  dataset = dataset.batch(32)\n  return dataset\n\n\ndef get_ds_test_input_fn():\n  np.random.seed(_RANDOM_SEED)\n  _, (x_test, y_test) = get_test_data(\n      train_samples=_TRAIN_SIZE,\n      test_samples=50,\n      input_shape=_INPUT_SIZE,\n      num_classes=_NUM_CLASS)\n  y_test = tf_keras.utils.to_categorical(y_test)\n\n  dataset = tf.compat.v1.data.Dataset.from_tensor_slices((x_test, 
y_test))\n  dataset = dataset.batch(32)\n  return dataset\n\n\ndef get_multi_inputs_multi_outputs_data():\n  (a_train, c_train), (a_test, c_test) = get_test_data(\n      train_samples=_TRAIN_SIZE,\n      test_samples=50,\n      input_shape=(16,),\n      num_classes=3,\n      random_seed=_RANDOM_SEED)\n  (b_train, d_train), (b_test, d_test) = get_test_data(\n      train_samples=_TRAIN_SIZE,\n      test_samples=50,\n      input_shape=(16,),\n      num_classes=2,\n      random_seed=_RANDOM_SEED)\n  (m_train, _), (m_test, _) = get_test_data(\n      train_samples=_TRAIN_SIZE,\n      test_samples=50,\n      input_shape=(8,),\n      num_classes=2,\n      random_seed=_RANDOM_SEED)\n\n  c_train = tf_keras.utils.to_categorical(c_train)\n  c_test = tf_keras.utils.to_categorical(c_test)\n  d_train = tf_keras.utils.to_categorical(d_train)\n  d_test = tf_keras.utils.to_categorical(d_test)\n\n  train_data = {\n      'input_a': a_train,\n      'input_b': b_train,\n      'input_m': m_train,\n      'output_c': c_train,\n      'output_d': d_train\n  }\n  test_data = {\n      'input_a': a_test,\n      'input_b': b_test,\n      'input_m': m_test,\n      'output_c': c_test,\n      'output_d': d_test\n  }\n\n  return (train_data, test_data)\n\n\nclass TestEstimatorDistributionStrategy(tf.test.TestCase,\n                                        parameterized.TestCase):\n\n  def setUp(self):\n    super(TestEstimatorDistributionStrategy, self).setUp()\n    strategy_combinations.set_virtual_cpus_to_at_least(3)\n    self._base_dir = os.path.join(self.get_temp_dir(),\n                                  'keras_to_estimator_strategy_test')\n    tf.compat.v1.gfile.MakeDirs(self._base_dir)\n    self._config = run_config_lib.RunConfig(\n        tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir)\n\n  def tearDown(self):\n    super(TestEstimatorDistributionStrategy, self).tearDown()\n    tf.compat.v1.summary.FileWriterCache.clear()\n    if os.path.isdir(self._base_dir):\n      
tf.compat.v1.gfile.DeleteRecursively(self._base_dir)\n\n  @tf.compat.v2.__internal__.distribute.combinations.generate(\n      tf.compat.v2.__internal__.test.combinations.combine(\n          distribution=[strategy_combinations.mirrored_strategy_with_two_cpus],\n          mode=['graph'],\n          cloning=[True, False],\n      )\n  )\n  def test_train_functional_with_distribution_strategy(self, distribution,\n                                                       cloning):\n    keras_model = simple_functional_model()\n    keras_model.compile(\n        loss='categorical_crossentropy',\n        metrics=[tf_keras.metrics.CategoricalAccuracy()],\n        optimizer=tf_keras.optimizers.legacy.RMSprop(learning_rate=0.01),\n        cloning=cloning)\n    config = run_config_lib.RunConfig(\n        tf_random_seed=_RANDOM_SEED,\n        model_dir=self._base_dir,\n        train_distribute=distribution,\n        eval_distribute=distribution)\n    with self.cached_session():\n      est_keras = keras_lib.model_to_estimator(\n          keras_model=keras_model, config=config)\n      before_eval_results = est_keras.evaluate(\n          input_fn=get_ds_test_input_fn, steps=1)\n      est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16)\n      after_eval_results = est_keras.evaluate(\n          input_fn=get_ds_test_input_fn, steps=1)\n      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])\n\n    tf.compat.v1.summary.FileWriterCache.clear()\n    tf.compat.v1.gfile.DeleteRecursively(self._config.model_dir)\n\n  @tf.compat.v2.__internal__.distribute.combinations.generate(\n      tf.compat.v2.__internal__.test.combinations.combine(\n          distribution=[strategy_combinations.mirrored_strategy_with_two_cpus],\n          mode=['graph'],\n          cloning=[True, False],\n      )\n  )\n  def test_train_sequential_with_distribution_strategy(self, distribution,\n                                                       cloning):\n    keras_model = 
simple_sequential_model()\n    keras_model.compile(\n        loss='categorical_crossentropy',\n        metrics=[tf_keras.metrics.CategoricalAccuracy()],\n        optimizer=tf_keras.optimizers.legacy.RMSprop(learning_rate=0.01),\n        cloning=cloning)\n    config = run_config_lib.RunConfig(\n        tf_random_seed=_RANDOM_SEED,\n        model_dir=self._base_dir,\n        train_distribute=distribution)\n    with self.cached_session():\n      est_keras = keras_lib.model_to_estimator(\n          keras_model=keras_model, config=config)\n      before_eval_results = est_keras.evaluate(\n          input_fn=get_ds_test_input_fn, steps=1)\n      est_keras.train(input_fn=get_ds_train_input_fn, steps=_TRAIN_SIZE / 16)\n      after_eval_results = est_keras.evaluate(\n          input_fn=get_ds_test_input_fn, steps=1)\n      self.assertLess(after_eval_results['loss'], before_eval_results['loss'])\n\n    tf.compat.v1.summary.FileWriterCache.clear()\n    tf.compat.v1.gfile.DeleteRecursively(self._config.model_dir)\n\n  @tf.compat.v2.__internal__.distribute.combinations.generate(\n      tf.compat.v2.__internal__.test.combinations.combine(\n          distribution=[strategy_combinations.mirrored_strategy_with_two_cpus],\n          mode=['graph'],\n      )\n  )\n  def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self, distribution):\n    train_data, test_data = get_multi_inputs_multi_outputs_data()\n\n    def train_input_fn():\n      input_dict = {\n          'input_a': train_data['input_a'],\n          'input_b': train_data['input_b'],\n          'input_m': train_data['input_m'].astype(str)\n      }\n      output_dict = {\n          'dense_2': train_data['output_c'],\n          'dense_3': train_data['output_d']\n      }\n      return tf.compat.v1.data.Dataset.from_tensor_slices(\n          (input_dict, output_dict)).batch(16)\n\n    def eval_input_fn():\n      input_dict = {\n          'input_a': test_data['input_a'],\n          'input_b': test_data['input_b'],\n          
'input_m': test_data['input_m'].astype(str)\n      }\n      output_dict = {\n          'dense_2': test_data['output_c'],\n          'dense_3': test_data['output_d']\n      }\n      return tf.compat.v1.data.Dataset.from_tensor_slices(\n          (input_dict, output_dict)).batch(16)\n\n    self.do_test_multi_inputs_multi_outputs_with_input_fn(\n        distribution, train_input_fn, eval_input_fn)\n\n  def do_test_multi_inputs_multi_outputs_with_input_fn(self, distribution,\n                                                       train_input_fn,\n                                                       eval_input_fn):\n    config = run_config_lib.RunConfig(\n        tf_random_seed=_RANDOM_SEED,\n        model_dir=self._base_dir,\n        train_distribute=distribution)\n    with self.cached_session():\n      model = multi_inputs_multi_outputs_model()\n      est_keras = keras_lib.model_to_estimator(keras_model=model, config=config)\n      baseline_eval_results = est_keras.evaluate(\n          input_fn=eval_input_fn, steps=1)\n      est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)\n      eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)\n      self.assertLess(eval_results['loss'], baseline_eval_results['loss'])\n\n\ndef get_test_data(train_samples,\n                  test_samples,\n                  input_shape,\n                  num_classes,\n                  random_seed=None):\n  if random_seed is not None:\n    np.random.seed(random_seed)\n  num_sample = train_samples + test_samples\n  templates = 2 * num_classes * np.random.random((num_classes,) + input_shape)\n  y = np.random.randint(0, num_classes, size=(num_sample,))\n  x = np.zeros((num_sample,) + input_shape, dtype=np.float32)\n  for i in range(num_sample):\n    x[i] = templates[y[i]] + np.random.normal(loc=0, scale=1., size=input_shape)\n  return ((x[:train_samples], y[:train_samples]),\n          (x[train_samples:], y[train_samples:]))\n\n\nif __name__ == '__main__':\n  
test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/keras_lib.py",
    "content": "# Copyright 2015 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n# pylint: disable=protected-access\n\"\"\"Home of estimator related functions.\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport os\nimport re\nfrom absl import logging\nimport tensorflow as tf\n\nfrom tensorflow.python.checkpoint import checkpoint as trackable_util\nfrom tensorflow_estimator.python.estimator import estimator as estimator_lib\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.export import export_lib\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v2\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v1\n\n\n_DEFAULT_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n\n\nclass FormattedKeyError(KeyError):\n  \"\"\"KeyError with formatted error message.\n\n  Python's `KeyError` has special casing around formatting\n  (see https://bugs.python.org/issue2651). 
Use this class when the error\n  message has newlines and other special format characters.\n\n  Needed by https://github.com/tensorflow/tensorflow/issues/36857.\n  \"\"\"\n\n  def __init__(self, message):\n    self.message = message\n\n  def __str__(self):\n    return self.message\n\n\ndef _cast_tensor_to_floatx(x):\n  \"\"\"Cast tensor to keras's floatx dtype if it is not already the same dtype.\"\"\"\n  if x.dtype == tf_keras.backend.floatx():\n    return x\n  else:\n    return tf.cast(x, tf_keras.backend.floatx())\n\n\ndef _convert_tensor(x):\n  \"\"\"Create or cast tensor if needed.\"\"\"\n  if not tf.is_tensor(x):\n    # x is a numpy array\n    x = tf.compat.v1.convert_to_tensor_or_sparse_tensor(x)\n  return x\n\n\ndef _any_weight_initialized(keras_model):\n  \"\"\"Check if any weights has been initialized in the Keras model.\n\n  Args:\n    keras_model: An instance of compiled keras model.\n\n  Returns:\n    boolean, True if at least one weight has been initialized, else False.\n    Currently keras initialize all weights at get_session().\n  \"\"\"\n  if keras_model is None:\n    return False\n  if tf.compat.v1.executing_eagerly_outside_functions():\n    return True\n  for layer in keras_model.layers:\n    for weight in layer.weights:\n      if hasattr(weight, '_keras_initialized'):\n        return True\n  return False\n\n\ndef _convert_estimator_io_to_keras(keras_model, features, labels):\n  \"\"\"Converts estimator features and labels to keras input and target tensors.\n\n  Args:\n    keras_model: a compiled `tf_keras.Model` instance, used to determine the\n      order of the returned lists.\n    features: Dict of tensors or `None`.\n    labels: Dict of tensors, a single tensor, or `None`.\n\n  Returns:\n    Tuple of (\n      list of input tensors or `None`,\n      list of target tensors or `None`,\n      list of sample weight tensors or `None`)\n    The order of tensors is determined by the order set in the keras model.\n  \"\"\"\n\n  def 
_to_ordered_tensor_list(obj, key_order, obj_name, order_name):\n    \"\"\"Convert obj to an ordered list of tensors.\n\n    Args:\n      obj: List, dict, or single tensor. May be `None`.\n      key_order: List of strings with the order to return (used if obj is a\n        dict).\n      obj_name: String name of object (e.g. \"features\" or \"labels\")\n      order_name: String name of the key order (e.g. \"inputs\" or \"outputs\")\n\n    Returns:\n      List of tensors, or `None`\n\n    Raises:\n      KeyError: If obj has invalid keys.\n    \"\"\"\n    if obj is None:\n      return None\n    elif isinstance(obj, (list, tuple)):\n      return [_convert_tensor(x) for x in obj]\n    elif isinstance(obj, dict):\n      # Ensure that keys in key_order are contained in obj keys.\n      # One can provide more data keys described in obj, as long as the keys\n      # requested by model are provided.\n      different_keys = set(key_order) - set(obj.keys())\n\n      if different_keys:\n        raise FormattedKeyError(\n            'The dictionary passed into {obj_name} does not cover requested '\n            '{order_name} keys defined in the keras model.'\n            '\\n\\tExpected keys: {order_keys}'\n            '\\n\\t{obj_name} keys: {obj_keys}'\n            '\\n\\tMissed keys: {different_keys}'.format(\n                order_name=order_name,\n                order_keys=set(key_order),\n                obj_name=obj_name,\n                obj_keys=set(obj.keys()),\n                different_keys=different_keys))\n\n      return [_convert_tensor(obj[key]) for key in key_order]\n    else:  # Assume obj is a tensor.\n      return [_convert_tensor(obj)]\n\n  features, sample_weight_tensors = _extract_sample_weight_tensors(features)\n  input_names = None\n  output_names = None\n  if isinstance(features, dict):\n    input_names = (\n        keras_model.input_names if keras_model._is_graph_network else\n        ['input_%d' % i for i in range(1,\n                                   
    len(features) + 1)])\n  if isinstance(labels, dict):\n    output_names = (\n        keras_model.output_names if keras_model._is_graph_network else\n        ['output_%d' % i for i in range(1,\n                                        len(labels) + 1)])\n\n  if isinstance(keras_model.inputs, dict):\n    # Keep input tensors as a dict if keras_model is built with dict input.\n    input_tensors = {\n        k: _convert_tensor(features[k])\n        for (k, v) in keras_model.inputs.items()\n    }\n  elif keras_model.inputs is None and isinstance(features, dict):\n    # Keep input tensors as a dict if keras_model input structure is unknown.\n    input_tensors = {k: _convert_tensor(v) for (k, v) in features.items()}\n  else:\n    # converting input tensors into sorted list.\n    input_tensors = _to_ordered_tensor_list(features, input_names, 'features',\n                                            'inputs')\n  target_tensors = _to_ordered_tensor_list(labels, output_names, 'labels',\n                                           'outputs')\n\n  return input_tensors, target_tensors, sample_weight_tensors\n\n\ndef _extract_sample_weight_tensors(features):\n  if isinstance(features, dict) and set(\n      features.keys()) == {'features', 'sample_weights'}:\n    feature_tensor = features['features']\n    sample_weight_tensors = features['sample_weights']\n  else:\n    feature_tensor = features\n    sample_weight_tensors = None\n  return feature_tensor, sample_weight_tensors\n\n\ndef _clone_and_build_model(mode,\n                           keras_model,\n                           custom_objects,\n                           features=None,\n                           labels=None,\n                           optimizer_config=None):\n  \"\"\"Clone and build the given keras_model.\n\n  Args:\n    mode: training mode.\n    keras_model: an instance of compiled keras model.\n    custom_objects: Dictionary for custom objects.\n    features: Dict of tensors.\n    labels: Dict of tensors, or 
single tensor instance.\n    optimizer_config: Optimizer config dictionary, returned by\n      `optimizer.get_config()`. This is used when cloning a model with an\n      optimizer. Since `_clone_and_build_model` is called in a different graph\n      and session from the model, `optimizer.get_config()` may raise an error\n      during the attempt to serialize the optimizer hyperparameter values.\n\n  Returns:\n    The newly built model.\n  \"\"\"\n  # Set to True during training, False for inference or testing.\n  tf_keras.backend.set_learning_phase(mode == ModeKeys.TRAIN)\n  input_tensors, target_tensors, sample_weight_tensors = (\n      _convert_estimator_io_to_keras(keras_model, features, labels))\n\n  compile_clone = (mode != ModeKeys.PREDICT)\n\n  global_step = None\n  if compile_clone:\n    # Set iterations to the global step created by tf.train.create_global_step()\n    # which is automatically run in the estimator framework.\n    global_step = tf.compat.v1.train.get_or_create_global_step()\n    tf_keras_v2.__internal__.backend.track_variable(global_step)\n\n  clone = tf_keras_v2.__internal__.models.clone_and_build_model(\n      keras_model,\n      input_tensors,\n      target_tensors,\n      custom_objects,\n      compile_clone=compile_clone,\n      in_place_reset=(not keras_model._is_graph_network),\n      optimizer_iterations=global_step,\n      optimizer_config=optimizer_config)\n\n  if sample_weight_tensors is not None:\n    sample_weight_tensors = standardize_sample_weights(\n        sample_weight_tensors, clone.output_names)\n    # Update calculated loss (model.total_loss) to include sample weights.\n    clone._compile_weights_loss_and_weighted_metrics(sample_weight_tensors)\n  return clone\n\n\ndef _convert_keras_metrics_to_estimator(model, metric_names_map=None):\n  \"\"\"Convert metrics from a Keras model to ops used by the Estimator framework.\n\n  Args:\n    model: A `tf_keras.Model` object.\n    metric_names_map: Optional dictionary mapping Keras 
model output metric\n      names to custom names.\n\n  Returns:\n    Dictionary mapping metric names to tuples of (value, update) ops. May return\n    `None` if the model does not contain any metrics.\n  \"\"\"\n  if not getattr(model, '_compile_metrics', None):\n    return None\n\n  # We are not using model.metrics here because we want to exclude the metrics\n  # added using `add_metric` API.\n  compiled_metrics = model._compile_metric_functions\n\n  if metric_names_map:\n    custom_map_keys = set(metric_names_map.keys())\n    expected_keys = {m.name for m in compiled_metrics}\n    unknown = expected_keys.difference(custom_map_keys)\n    if unknown:\n      raise ValueError(\n          'Invalid `metric_names_map`. '\n          'The following keras model metric names:\"{}\" do not exist in '\n          'the `metric_names_map` dictionary'.format(list(unknown)))\n\n    extra = custom_map_keys.difference(expected_keys)\n    if extra:\n      raise ValueError('Invalid `metric_names_map`. '\n                       'There are unexpected keys in the `metric_names_map` '\n                       'dictionary. 
Expected keys: {}, Received: {}'.format(\n                           list(expected_keys), list(extra)))\n\n    return {metric_names_map[m.name]: m for m in compiled_metrics}\n  else:\n    return {m.name: m for m in compiled_metrics}\n\n\ndef _create_keras_model_fn(keras_model,\n                           custom_objects=None,\n                           save_object_ckpt=False,\n                           metric_names_map=None,\n                           export_outputs=None):\n  \"\"\"Creates model_fn for keras Estimator.\n\n  Args:\n    keras_model: an instance of compiled keras model.\n    custom_objects: Dictionary for custom objects.\n    save_object_ckpt: Whether to save an object-based checkpoint.\n    metric_names_map: Optional dictionary mapping Keras model output metric\n      names to custom names.\n    export_outputs: Optional dictionary mapping custom names to a subclass of\n      `tf.estimator.export.ExportOutput`.\n\n  Returns:\n    The model_fn for a keras Estimator.\n  \"\"\"\n  if isinstance(keras_model.optimizer,\n                tf_keras.optimizers.experimental.Optimizer):\n    # Experimental optimizer cannot work with estimator, so we convert it to\n    # legacy optimizer.\n    if tf.executing_eagerly():\n      logging.warning(\n          'You are using `tf_keras.optimizers.experimental.Optimizer` in TF '\n          'estimator, which only supports '\n          '`tf_keras.optimizers.legacy.Optimizer`. Automatically converting '\n          'your optimizer to `tf_keras.optimizers.legacy.Optimizer`.')\n      opt = tf_keras.__internal__.optimizers.convert_to_legacy_optimizer(\n          keras_model.optimizer)\n      keras_model.optimizer = opt\n    else:\n      raise ValueError('Please set your optimizer as an instance of '\n                       '`tf_keras.optimizers.legacy.Optimizer`, e.g., '\n                       '`tf_keras.optimizers.legacy.Adam`. 
Received optimizer '\n                       f'type: {type(keras_model.optimizer)}.')\n  # Get optimizer config in the current context (since model_fn is called in the\n  # estimator graph and session). OptimizerV2 objects serialize variable/tensor\n  # hyperparameters in their configs, resulting to wrong-session errors during\n  # model cloning.\n  try:\n    if isinstance(keras_model.optimizer, (tuple, list)):\n      optimizer_config = [opt.get_config() for opt in keras_model.optimizer]\n    else:\n      optimizer_config = keras_model.optimizer.get_config()\n  except (NotImplementedError, AttributeError):\n    # TFOptimizers and other custom optimizers do not have a config.\n    optimizer_config = None\n\n  def model_fn(features, labels, mode):\n    \"\"\"model_fn for keras Estimator.\"\"\"\n    model = _clone_and_build_model(\n        mode=mode,\n        keras_model=keras_model,\n        custom_objects=custom_objects,\n        features=features,\n        labels=labels,\n        optimizer_config=optimizer_config)\n    model_output_names = []\n    # We need to make sure that the output names of the last layer in the model\n    # is the same for each of the cloned models. 
This is required for mirrored\n    # strategy when we call regroup.\n    if tf.distribute.has_strategy():\n      for name in model.output_names:\n        name = re.compile(r'_\\d$').sub('', name)\n        model_output_names.append(name)\n    else:\n      model_output_names = model.output_names\n\n    # Get inputs to EstimatorSpec\n    predictions = dict(zip(model_output_names, model.outputs))\n\n    loss = None\n    train_op = None\n    eval_metric_ops = None\n\n    # Set loss and metric only during train and evaluate.\n    if mode is not ModeKeys.PREDICT:\n      if mode is ModeKeys.TRAIN:\n        model._make_train_function()  # pylint: disable=protected-access\n      else:\n        model._make_test_function()  # pylint: disable=protected-access\n      loss = model.total_loss\n\n      eval_metric_ops = _convert_keras_metrics_to_estimator(\n          model, metric_names_map)\n\n    # Set train_op only during train.\n    if mode is ModeKeys.TRAIN:\n      train_op = model.train_function.updates_op\n\n    if (not model._is_graph_network and\n        hasattr(keras_model, '_original_attributes_cache') and\n        keras_model._original_attributes_cache is not None):\n      # To avoid `model_fn` being destructive for the initial model argument.\n      (tf_keras_v2.__internal__.models.\n       in_place_subclassed_model_state_restoration(keras_model))\n\n    scaffold = None\n    if save_object_ckpt:\n      model._track_trackable(tf.compat.v1.train.get_global_step(),\n                             'estimator_global_step')\n      # Create saver that maps variable names to object-checkpoint keys.\n      object_graph = tf.compat.v2.__internal__.tracking.ObjectGraphView(model)\n      var_list = object_graph.frozen_saveable_objects()\n      saver = tf.compat.v1.train.Saver(var_list=var_list, sharded=True)\n      saver._object_restore_saver = trackable_util.frozen_saver(model)\n      scaffold = tf.compat.v1.train.Scaffold(saver=saver)\n\n    final_export_outputs = {\n        
_DEFAULT_SERVING_KEY: export_lib.PredictOutput(predictions)\n    }\n    if export_outputs is not None:\n      different_keys = set(export_outputs.keys()) - set(model.output_names)\n      if different_keys:\n        raise FormattedKeyError(\n            'The list passed into {obj_name} does not cover requested '\n            '{order_name} keys defined in the keras model.'\n            '\\n\\tExpected keys: {order_keys}'\n            '\\n\\t{obj_name} keys: {obj_keys}'\n            '\\n\\tMissed keys: {different_keys}'.format(\n                order_name=export_outputs,\n                order_keys=set(export_outputs.keys()),\n                obj_name=model.output_names,\n                obj_keys=set(model.output_names),\n                different_keys=different_keys))\n      for key, export_output_cls in export_outputs.items():\n        final_export_outputs[key] = export_output_cls(predictions[key])\n\n    return model_fn_lib.EstimatorSpec(\n        mode=mode,\n        predictions=predictions,\n        loss=loss,\n        train_op=train_op,\n        eval_metric_ops=eval_metric_ops,\n        export_outputs=final_export_outputs,\n        scaffold=scaffold)\n\n  return model_fn\n\n\ndef _save_first_checkpoint(keras_model, custom_objects, config,\n                           save_object_ckpt):\n  \"\"\"Save first checkpoint for the keras Estimator.\n\n  Args:\n    keras_model: an instance of compiled keras model.\n    custom_objects: Dictionary for custom objects.\n    config: Estimator config.\n    save_object_ckpt: Whether to save an object-based checkpoint.\n\n  Returns:\n    The path where keras model checkpoint is saved.\n  \"\"\"\n  # save checkpoint into subdirectory to allow warm start\n  keras_model_dir = os.path.join(config.model_dir, 'keras')\n  # Load weights and save to checkpoint if there is no checkpoint\n  latest_path = tf.train.latest_checkpoint(keras_model_dir)\n  if not latest_path:\n    keras_weights = None\n    if 
_any_weight_initialized(keras_model):\n      keras_weights = keras_model.get_weights()\n    if not tf.compat.v1.gfile.IsDirectory(keras_model_dir):\n      tf.compat.v1.gfile.MakeDirs(keras_model_dir)\n    with tf.Graph().as_default():\n      tf.compat.v1.random.set_random_seed(config.tf_random_seed)\n      tf.compat.v1.train.create_global_step()\n      model = _clone_and_build_model(ModeKeys.TRAIN, keras_model,\n                                     custom_objects)\n\n      # Init the train_function outside of the context of session. This is due\n      # to the fact that train function will update the graph by adding backprop\n      # parts. This will potentially trying to update the node in forward graph\n      # which will fail if it is done within same session.\n      # Always create the train_function here since the model is just cloned.\n      # See https://github.com/tensorflow/tensorflow/issues/27750 for details.\n      model._make_train_function()  # pylint: disable=protected-access\n\n      # save to checkpoint\n      with tf.compat.v1.Session(config=config.session_config) as sess:\n        if keras_weights:\n          model.set_weights(keras_weights)\n        # model._make_train_function() will potentially create the optimizer\n        # variable, which will require another variable initialization.\n        tf_keras_v2.__internal__.backend.initialize_variables(sess)\n\n        if save_object_ckpt:\n          model._track_trackable(  # pylint: disable=protected-access\n              tf.compat.v1.train.get_global_step(), 'estimator_global_step')\n          latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt')\n          model.save_weights(latest_path)\n        else:\n          saver = tf.compat.v1.train.Saver()\n          latest_path = os.path.join(keras_model_dir, 'keras_model.ckpt')\n          saver.save(sess, latest_path)\n\n  return latest_path\n\n\ndef _get_file_from_google_storage(keras_model_path, model_dir):\n  \"\"\"Get file from google 
storage and download to local file.\n\n  Args:\n    keras_model_path: a google storage path for compiled keras model.\n    model_dir: the directory from estimator config.\n\n  Returns:\n    The path where keras model is saved.\n\n  Raises:\n    ValueError: if storage object name does not end with .h5.\n  \"\"\"\n  try:\n    from google.cloud import storage  # pylint:disable=g-import-not-at-top\n  except ImportError:\n    raise TypeError('Could not save model to Google cloud storage; please '\n                    'install `google-cloud-storage` via '\n                    '`pip install google-cloud-storage`.')\n  storage_client = storage.Client()\n  path, blob_name = os.path.split(keras_model_path)\n  _, bucket_name = os.path.split(path)\n  keras_model_dir = os.path.join(model_dir, 'keras')\n  if not tf.compat.v1.gfile.Exists(keras_model_dir):\n    tf.compat.v1.gfile.MakeDirs(keras_model_dir)\n  file_name = os.path.join(keras_model_dir, 'keras_model.h5')\n  try:\n    blob = storage_client.get_bucket(bucket_name).blob(blob_name)\n    blob.download_to_filename(file_name)\n  except:\n    raise ValueError('Failed to download keras model, please check '\n                     'environment variable GOOGLE_APPLICATION_CREDENTIALS '\n                     'and model path storage.googleapis.com/{bucket}/{object}.')\n  tf.compat.v1.logging.info('Saving model to {}'.format(file_name))\n  del storage_client\n  return file_name\n\n\ndef model_to_estimator(keras_model=None,\n                       keras_model_path=None,\n                       custom_objects=None,\n                       model_dir=None,\n                       config=None,\n                       checkpoint_format=None,\n                       use_v2_estimator=False,\n                       metric_names_map=None,\n                       export_outputs=None):\n  \"\"\"Constructs an `Estimator` instance from given keras model.\n\n  If you use infrastructure or other tooling that relies on Estimators, you can\n  still 
build a Keras model and use model_to_estimator to convert the Keras\n  model to an Estimator for use with downstream systems.\n\n  For usage example, please see:\n  [Creating estimators from Keras\n  Models](https://www.tensorflow.org/guide/estimator#create_an_estimator_from_a_keras_model).\n\n  Sample Weights:\n  Estimators returned by `model_to_estimator` are configured so that they can\n  handle sample weights (similar to `keras_model.fit(x, y, sample_weights)`).\n\n  To pass sample weights when training or evaluating the Estimator, the first\n  item returned by the input function should be a dictionary with keys\n  `features` and `sample_weights`. Example below:\n\n  ```python\n  keras_model = tf_keras.Model(...)\n  keras_model.compile(...)\n\n  estimator = tf_keras.estimator.model_to_estimator(keras_model)\n\n  def input_fn():\n    return dataset_ops.Dataset.from_tensors(\n        ({'features': features, 'sample_weights': sample_weights},\n         targets))\n\n  estimator.train(input_fn, steps=1)\n  ```\n\n  Example with customized export signature:\n  ```python\n  inputs = {'a': tf_keras.Input(..., name='a'),\n            'b': tf_keras.Input(..., name='b')}\n  outputs = {'c': tf_keras.layers.Dense(..., name='c')(inputs['a']),\n             'd': tf_keras.layers.Dense(..., name='d')(inputs['b'])}\n  keras_model = tf_keras.Model(inputs, outputs)\n  keras_model.compile(...)\n  export_outputs = {'c': tf.estimator.export.RegressionOutput,\n                    'd': tf.estimator.export.ClassificationOutput}\n\n  estimator = tf_keras.estimator.model_to_estimator(\n      keras_model, export_outputs=export_outputs)\n\n  def input_fn():\n    return dataset_ops.Dataset.from_tensors(\n        ({'features': features, 'sample_weights': sample_weights},\n         targets))\n\n  estimator.train(input_fn, steps=1)\n  ```\n\n  Note: We do not support creating weighted metrics in Keras and converting them\n  to weighted metrics in the Estimator API using `model_to_estimator`.\n  
You will have to create these metrics directly on the estimator spec using the\n  `add_metrics` function.\n\n  Args:\n    keras_model: A compiled Keras model object. This argument is mutually\n      exclusive with `keras_model_path`. Estimator's `model_fn` uses the\n      structure of the model to clone the model. Defaults to `None`.\n    keras_model_path: Path to a compiled Keras model saved on disk, in HDF5\n      format, which can be generated with the `save()` method of a Keras model.\n      This argument is mutually exclusive with `keras_model`.\n      Defaults to `None`.\n    custom_objects: Dictionary for cloning customized objects. This is\n      used with classes that is not part of this pip package. For example, if\n      user maintains a `relu6` class that inherits from `tf_keras.layers.Layer`,\n      then pass `custom_objects={'relu6': relu6}`. Defaults to `None`.\n    model_dir: Directory to save `Estimator` model parameters, graph, summary\n      files for TensorBoard, etc. If unset a directory will be created with\n      `tempfile.mkdtemp`\n    config: `RunConfig` to config `Estimator`. Allows setting up things in\n      `model_fn` based on configuration such as `num_ps_replicas`, or\n      `model_dir`. Defaults to `None`. If both `config.model_dir` and the\n      `model_dir` argument (above) are specified the `model_dir` **argument**\n      takes precedence.\n    checkpoint_format: Sets the format of the checkpoint saved by the estimator\n      when training. May be `saver` or `checkpoint`, depending on whether to\n      save checkpoints from `tf.compat.v1.train.Saver` or `tf.train.Checkpoint`.\n      The default is `checkpoint`. Estimators use name-based `tf.train.Saver`\n      checkpoints, while Keras models use object-based checkpoints from\n      `tf.train.Checkpoint`. 
Currently, saving object-based checkpoints from\n      `model_to_estimator` is only supported by Functional and Sequential\n      models.\n    use_v2_estimator: Whether to convert the model to a V2 Estimator or V1\n      Estimator. Defaults to `False`.\n    metric_names_map: Optional dictionary mapping Keras model output metric\n      names to custom names. This can be used to override the default Keras\n      model output metrics names in a multi IO model use case and provide custom\n      names for the `eval_metric_ops` in Estimator.\n      The Keras model metric names can be obtained using `model.metrics_names`\n      excluding any loss metrics such as total loss and output losses.\n      For example, if your Keras model has two outputs `out_1` and `out_2`,\n      with `mse` loss and `acc` metric, then `model.metrics_names` will be\n      `['loss', 'out_1_loss', 'out_2_loss', 'out_1_acc', 'out_2_acc']`.\n      The model metric names excluding the loss metrics will be\n      `['out_1_acc', 'out_2_acc']`.\n    export_outputs: Optional dictionary. This can be used to override the\n      default Keras model output exports in a multi IO model use case and\n      provide custom names for the `export_outputs` in\n      `tf.estimator.EstimatorSpec`. Default is None, which is equivalent to\n      {'serving_default': `tf.estimator.export.PredictOutput`}.\n      A dict `{name: output}` where:\n        * name: An arbitrary name for this output. This becomes the signature\n          name in the SavedModel.\n        * output: an `ExportOutput` object such as `ClassificationOutput`,\n          `RegressionOutput`, or `PredictOutput`. Single-headed models only need\n          to specify one entry in this dictionary. 
Multi-headed models should\n          specify one entry for each head, one of which must be named using\n          `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`.\n          If no entry is provided, a default `PredictOutput` mapping to\n          `predictions` will be created.\n\n  Returns:\n    An Estimator from given keras model.\n\n  Raises:\n    ValueError: If neither keras_model nor keras_model_path was given.\n    ValueError: If both keras_model and keras_model_path was given.\n    ValueError: If the keras_model_path is a GCS URI.\n    ValueError: If keras_model has not been compiled.\n    ValueError: If an invalid checkpoint_format was given.\n  \"\"\"\n\n  if not (keras_model or keras_model_path):\n    raise ValueError(\n        'Either `keras_model` or `keras_model_path` needs to be provided.')\n  if keras_model and keras_model_path:\n    raise ValueError(\n        'Please specity either `keras_model` or `keras_model_path`, '\n        'but not both.')\n\n  if keras_model:\n    _assert_valid_model(keras_model, custom_objects)\n\n  config = estimator_lib.maybe_overwrite_model_dir_and_session_config(\n      config, model_dir)\n  if not keras_model:\n    if keras_model_path.startswith(\n        'gs://') or 'storage.googleapis.com' in keras_model_path:\n      keras_model_path = _get_file_from_google_storage(keras_model_path,\n                                                       config.model_dir)\n    tf.compat.v1.logging.info('Loading models from %s', keras_model_path)\n    keras_model = tf_keras.models.load_model(keras_model_path)\n  else:\n    tf.compat.v1.logging.info('Using the Keras model provided.')\n    keras_model = keras_model\n\n  if checkpoint_format is None or checkpoint_format == 'checkpoint':\n    if not (keras_model._is_graph_network or\n            isinstance(keras_model, tf_keras.models.Sequential)):\n      raise ValueError('Object-based checkpoints are currently not supported '\n                       'with subclassed 
models.')\n    save_object_ckpt = True\n  elif checkpoint_format == 'saver':\n    save_object_ckpt = False\n  else:\n    raise ValueError(\n        'Checkpoint format must be one of \"checkpoint\" or \"saver\". Got {}'\n        .format(checkpoint_format))\n\n  if not hasattr(keras_model, 'optimizer') or not keras_model.optimizer:\n    raise ValueError('The given keras model has not been compiled yet. '\n                     'Please compile the model with `model.compile()` '\n                     'before calling `model_to_estimator()`.')\n  keras_model_fn = _create_keras_model_fn(\n      keras_model, custom_objects, save_object_ckpt, metric_names_map,\n      export_outputs)\n  if _any_weight_initialized(keras_model):\n    # Warn if config passed to estimator tries to update GPUOptions. If a\n    # session has already been created, the GPUOptions passed to the first\n    # session sticks.\n    if config.session_config.HasField('gpu_options'):\n      tf.compat.v1.logging.warn(\n          'The Keras backend session has already been set. '\n          'The _session_config passed to model_to_estimator will not be used.')\n  else:\n    # Pass the config into keras backend's default session.\n    sess = tf.compat.v1.Session(config=config.session_config)\n    tf_keras_v1.backend.set_session(sess)\n\n  warm_start_path = None\n  if keras_model._is_graph_network and config.is_chief:\n    warm_start_path = _save_first_checkpoint(keras_model, custom_objects,\n                                             config, save_object_ckpt)\n  elif keras_model.built:\n    tf.compat.v1.logging.warn(\n        'You are creating an Estimator from a Keras model manually '\n        'subclassed from `Model`, that was already called on some '\n        'inputs (and thus already had weights). We are currently '\n        'unable to preserve the model\\'s state (its weights) as '\n        'part of the estimator in this case. 
Be warned that the '\n        'estimator has been created using a freshly initialized '\n        'version of your model.\\n'\n        'Note that this doesn\\'t affect the state of the model '\n        'instance you passed as `keras_model` argument.')\n  if use_v2_estimator:\n    estimator_cls = estimator_lib.EstimatorV2\n  else:\n    estimator_cls = estimator_lib.Estimator\n\n  estimator = estimator_cls(\n      keras_model_fn, config=config, warm_start_from=warm_start_path)\n\n  return estimator\n\n\ndef _assert_valid_model(model, custom_objects=None):\n  is_subclass = (not model._is_graph_network and\n                 not isinstance(model, tf_keras.models.Sequential))\n  if is_subclass:\n    try:\n      custom_objects = custom_objects or {}\n      with tf_keras.utils.CustomObjectScope(custom_objects):\n        model.__class__.from_config(model.get_config())\n    except NotImplementedError:\n      raise ValueError(\n          'Subclassed `Model`s passed to `model_to_estimator` must '\n          'implement `Model.get_config` and `Model.from_config`.')\n\n\ndef standardize_sample_weights(x_weight, output_names):\n  \"\"\"Maps `sample_weight` or `class_weight` to model outputs.\n\n  Args:\n      x_weight: User-provided `sample_weight` or `class_weight` argument.\n      output_names: List of output names (strings) in the model.\n\n  Returns:\n      A list of `sample_weight` or `class_weight` where there are exactly\n          one element per model output.\n\n  Raises:\n      ValueError: In case of invalid user-provided argument.\n  \"\"\"\n  if x_weight is None or (isinstance(x_weight, (list, tuple)) and\n                          len(x_weight) == 0):  # pylint: disable=g-explicit-length-test\n    return [None for _ in output_names]\n  if len(output_names) == 1:\n    if isinstance(x_weight, (list, tuple)) and len(x_weight) == 1:\n      return x_weight\n    if isinstance(x_weight, dict) and output_names[0] in x_weight:\n      return [x_weight[output_names[0]]]\n    
else:\n      return [x_weight]\n  if isinstance(x_weight, (list, tuple)):\n    if len(x_weight) != len(output_names):\n      raise ValueError('Provided `sample_weights` was a list of ' +\n                       str(len(x_weight)) + ' elements, but the model has ' +\n                       str(len(output_names)) + ' outputs. '\n                       'You should provide one `sample_weights`'\n                       'array per model output.')\n    return x_weight\n  if isinstance(x_weight, collections.abc.Mapping):\n    unknown = set(x_weight.keys()).difference(output_names)\n    if unknown:\n      raise ValueError('Unknown entries in sample_weights dictionary: {}. '\n                       'Only expected following keys: {}'.format(\n                           list(unknown), output_names))\n    x_weights = []\n    for name in output_names:\n      x_weights.append(x_weight.get(name))\n    return x_weights\n  else:\n    raise TypeError('The model has multiple outputs, so `sample_weights` '\n                    'should be either a list or a dict. '\n                    'Provided `sample_weights` type not understood: ' +\n                    str(x_weight))\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/keras_premade_model_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for keras premade model in model_to_estimator routines.\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator import keras_lib\nfrom tensorflow_estimator.python.estimator import run_config as run_config_lib\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v1\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\n\n_RANDOM_SEED = 1337\n\n\ndef gen_input_fn(x, y=None, batch_size=32, num_epochs=10, shuffle=False):\n\n  def input_fn():\n    ds = tf.compat.v1.data.Dataset.from_tensor_slices((\n        x, y) if y is not None else x)\n    if shuffle:\n      ds = ds.shuffle(1000)\n    return ds.repeat(num_epochs).batch(batch_size)\n\n  return input_fn\n\n\ndef get_resource_for_simple_model():\n\n  input_name = 'input_1'\n  output_name = 'output_1'\n\n  np.random.seed(_RANDOM_SEED)\n  x_train = np.random.uniform(low=-5, high=5, size=(64, 2)).astype('f')\n  y_train = .3 * x_train[:, 0] + .2 * x_train[:, 1]\n  x_test = np.random.uniform(low=-5, high=5, size=(64, 2)).astype('f')\n  y_test = .3 * 
x_test[:, 0] + .2 * x_test[:, 1]\n\n  train_input_fn = gen_input_fn(\n      x=x_train, y=y_train, num_epochs=None, shuffle=False)\n\n  evaluate_input_fn = gen_input_fn(\n      x=randomize_io_type(x_test, input_name),\n      y=randomize_io_type(y_test, output_name),\n      num_epochs=1,\n      shuffle=False)\n\n  return (x_train, y_train), (x_test, y_test), train_input_fn, evaluate_input_fn\n\n\ndef randomize_io_type(array, name):\n  switch = np.random.random()\n  if switch > 0.5:\n    return array\n  else:\n    return {name: array}\n\n\nclass KerasPremadeModelTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._base_dir = os.path.join(self.get_temp_dir(), 'keras_estimator_test')\n    tf.compat.v1.gfile.MakeDirs(self._base_dir)\n    self._config = run_config_lib.RunConfig(\n        tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir)\n    super(KerasPremadeModelTest, self).setUp()\n\n  def tearDown(self):\n    # Make sure nothing is stuck in limbo.\n    tf.compat.v1.summary.FileWriterCache.clear()\n    if os.path.isdir(self._base_dir):\n      tf.compat.v1.gfile.DeleteRecursively(self._base_dir)\n    tf_keras.backend.clear_session()\n    super(KerasPremadeModelTest, self).tearDown()\n\n  def test_train_premade_linear_model_with_dense_features(self):\n    vocab_list = ['alpha', 'beta', 'gamma']\n    vocab_val = [0.4, 0.6, 0.9]\n    data = np.random.choice(vocab_list, size=256)\n    y = np.zeros_like(data, dtype=np.float32)\n    for vocab, val in zip(vocab_list, vocab_val):\n      indices = np.where(data == vocab)\n      y[indices] = val + np.random.uniform(\n          low=-0.01, high=0.01, size=indices[0].shape)\n    cat_column = tf.feature_column.categorical_column_with_vocabulary_list(\n        key='symbol', vocabulary_list=vocab_list)\n    ind_column = tf.feature_column.indicator_column(cat_column)\n    keras_input = tf_keras.layers.Input(\n        name='symbol', shape=3, dtype=tf.dtypes.string)\n    feature_layer = 
tf_keras_v1.layers.DenseFeatures([ind_column])\n    h = feature_layer({'symbol': keras_input})\n    linear_model = tf_keras.experimental.LinearModel(units=1)\n    h = linear_model(h)\n\n    model = tf_keras.models.Model(inputs=keras_input, outputs=h)\n    opt = tf_keras.optimizers.legacy.SGD(0.1)\n    model.compile(opt, 'mse', ['mse'])\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'symbol': data}, y=y, num_epochs=20, shuffle=False)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'symbol': data}, y=y, num_epochs=20, shuffle=False)\n    est = keras_lib.model_to_estimator(\n        keras_model=model, config=self._config, checkpoint_format='saver')\n    before_eval_results = est.evaluate(input_fn=eval_input_fn, steps=1)\n    est.train(input_fn=train_input_fn, steps=30)\n    after_eval_results = est.evaluate(input_fn=eval_input_fn, steps=1)\n    self.assertLess(after_eval_results['loss'], before_eval_results['loss'])\n    self.assertLess(after_eval_results['loss'], 0.05)\n\n  def test_train_premade_linear_model(self):\n    (x_train,\n     y_train), _, train_inp_fn, eval_inp_fn = get_resource_for_simple_model()\n\n    linear_model = tf_keras.experimental.LinearModel(units=1)\n    opt = tf_keras.optimizers.legacy.SGD(0.1)\n    linear_model.compile(opt, 'mse', ['mse'])\n    linear_model.fit(x_train, y_train, epochs=10)\n\n    est = keras_lib.model_to_estimator(\n        keras_model=linear_model,\n        config=self._config,\n        checkpoint_format='saver')\n    before_eval_results = est.evaluate(input_fn=eval_inp_fn, steps=1)\n    est.train(input_fn=train_inp_fn, steps=500)\n    after_eval_results = est.evaluate(input_fn=eval_inp_fn, steps=1)\n    self.assertLess(after_eval_results['loss'], before_eval_results['loss'])\n    self.assertLess(after_eval_results['loss'], 0.1)\n\n  def test_train_premade_widedeep_model_with_feature_layers(self):\n    vocab_list = ['alpha', 'beta', 'gamma']\n    vocab_val = [0.4, 0.6, 0.9]\n    data = 
np.random.choice(vocab_list, size=256)\n    y = np.zeros_like(data, dtype=np.float32)\n    for vocab, val in zip(vocab_list, vocab_val):\n      indices = np.where(data == vocab)\n      y[indices] = val + np.random.uniform(\n          low=-0.01, high=0.01, size=indices[0].shape)\n    cat_column = tf.feature_column.categorical_column_with_vocabulary_list(\n        key='symbol', vocabulary_list=vocab_list)\n    ind_column = tf.feature_column.indicator_column(cat_column)\n    # TODO(tanzheny): use emb column for dense part once b/139667019 is fixed.\n    # emb_column = feature_column.embedding_column(cat_column, dimension=5)\n    keras_input = tf_keras.layers.Input(\n        name='symbol', shape=3, dtype=tf.dtypes.string)\n\n    # build linear part with feature layer.\n    linear_feature_layer = tf_keras_v1.layers.DenseFeatures([ind_column])\n    linear_model = tf_keras.experimental.LinearModel(\n        units=1, name='Linear', kernel_initializer='zeros')\n    combined_linear = tf_keras.models.Sequential([linear_feature_layer, linear_model])\n\n    # build dnn part with feature layer.\n    dnn_feature_layer = tf_keras_v1.layers.DenseFeatures([ind_column])\n    dense_layer = tf_keras.layers.Dense(\n        units=1, name='DNNDense', kernel_initializer='zeros')\n    combined_dnn = tf_keras.models.Sequential([dnn_feature_layer, dense_layer])\n\n    # build and compile wide deep.\n    wide_deep_model = tf_keras.experimental.WideDeepModel(combined_linear, combined_dnn)\n    wide_deep_model._set_inputs({'symbol': keras_input})\n    sgd_opt = tf_keras.optimizers.legacy.SGD(0.1)\n    adam_opt = tf_keras.optimizers.legacy.Adam(0.1)\n    wide_deep_model.compile([sgd_opt, adam_opt], 'mse', ['mse'])\n\n    # build estimator.\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'symbol': data}, y=y, num_epochs=20, shuffle=False)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'symbol': data}, y=y, num_epochs=20, shuffle=False)\n    est = 
keras_lib.model_to_estimator(\n        keras_model=wide_deep_model,\n        config=self._config,\n        checkpoint_format='saver')\n\n    before_eval_results = est.evaluate(input_fn=eval_input_fn, steps=1)\n    est.train(input_fn=train_input_fn, steps=20)\n    after_eval_results = est.evaluate(input_fn=eval_input_fn, steps=1)\n    self.assertLess(after_eval_results['loss'], before_eval_results['loss'])\n    self.assertLess(after_eval_results['loss'], 0.1)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/keras_test.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for training routines.\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport json\nimport math\nimport os\nimport tempfile\n\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow.python.ops.parsing_ops import gen_parsing_ops\nfrom tensorflow.python.saved_model import path_helpers\nfrom tensorflow.python.saved_model.model_utils import export_output\nfrom tensorflow.python.training import saver as saver_lib\nfrom tensorflow_estimator.python.estimator import keras_lib\nfrom tensorflow_estimator.python.estimator import run_config as run_config_lib\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v1\nfrom tensorflow_estimator.python.estimator.export import export_lib\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v2\n\ntry:\n  import h5py  # pylint:disable=g-import-not-at-top\nexcept ImportError:\n  h5py = None\n\n_RANDOM_SEED = 1337\n_TRAIN_SIZE = 
200\n_INPUT_SIZE = (10,)\n_NUM_CLASS = 2\n\n_TMP_DIR = '/tmp'\n\n\ndef simple_sequential_model():\n  model = tf_keras.models.Sequential()\n  model.add(tf_keras.layers.Dense(16, activation='relu', input_shape=_INPUT_SIZE))\n  model.add(tf_keras.layers.Dropout(0.1))\n  model.add(tf_keras.layers.Dense(_NUM_CLASS, activation='softmax'))\n  return model\n\n\ndef simple_functional_model(activation='relu'):\n  a = tf_keras.layers.Input(shape=_INPUT_SIZE, name='input_layer')\n  b = tf_keras.layers.Dense(16, activation=activation)(a)\n  b = tf_keras.layers.Dropout(0.1)(b)\n  b = tf_keras.layers.Dense(_NUM_CLASS, activation='softmax')(b)\n  model = tf_keras.models.Model(inputs=[a], outputs=[b])\n  return model\n\n\ndef simple_subclassed_model():\n\n  class SimpleModel(tf_keras.models.Model):\n\n    def __init__(self):\n      super(SimpleModel, self).__init__()\n      self.dense1 = tf_keras.layers.Dense(16, activation='relu')\n      self.dp = tf_keras.layers.Dropout(0.1)\n      self.dense2 = tf_keras.layers.Dense(_NUM_CLASS, activation='softmax')\n\n    def call(self, inputs):\n      x = self.dense1(inputs)\n      x = self.dp(x)\n      return self.dense2(x)\n\n    def get_config(self):\n      return {}\n\n    @classmethod\n    def from_config(cls, config):\n      return cls()\n\n  return SimpleModel()\n\n\ndef gen_input_fn(x, y=None, batch_size=128, num_epochs=1, shuffle=False):\n\n  def input_fn():\n    ds = tf.compat.v1.data.Dataset.from_tensor_slices((\n        x, y) if y is not None else x)\n    if shuffle:\n      ds = ds.shuffle(1000)\n    return ds.repeat(num_epochs).batch(batch_size)\n\n  return input_fn\n\n\ndef get_multi_inputs_multi_outputs_data():\n  (a_train, c_train), (a_test, c_test) = get_test_data(\n      train_samples=_TRAIN_SIZE,\n      test_samples=50,\n      input_shape=(16,),\n      num_classes=3,\n      random_seed=_RANDOM_SEED)\n  (b_train, d_train), (b_test, d_test) = get_test_data(\n      train_samples=_TRAIN_SIZE,\n      test_samples=50,\n      
input_shape=(16,),\n      num_classes=2,\n      random_seed=_RANDOM_SEED)\n  (m_train, _), (m_test, _) = get_test_data(\n      train_samples=_TRAIN_SIZE,\n      test_samples=50,\n      input_shape=(8,),\n      num_classes=2,\n      random_seed=_RANDOM_SEED)\n\n  c_train = tf_keras.utils.to_categorical(c_train)\n  c_test = tf_keras.utils.to_categorical(c_test)\n  d_train = tf_keras.utils.to_categorical(d_train)\n  d_test = tf_keras.utils.to_categorical(d_test)\n\n  train_data = {\n      'input_a': a_train,\n      'input_b': b_train,\n      'input_m': m_train,\n      'output_c': c_train,\n      'output_d': d_train\n  }\n  test_data = {\n      'input_a': a_test,\n      'input_b': b_test,\n      'input_m': m_test,\n      'output_c': c_test,\n      'output_d': d_test\n  }\n\n  return (train_data, test_data)\n\n\ndef get_resource_for_simple_model(\n    model_type='sequential',\n    is_evaluate=False,\n):\n  if model_type == 'sequential':\n    model = simple_sequential_model()\n    model.build()\n  elif model_type == 'subclass':\n    model = simple_subclassed_model()\n  else:\n    assert model_type == 'functional'\n    model = simple_functional_model()\n\n  if model_type == 'subclass':\n    input_name = 'input_1'\n    output_name = 'output_1'\n  else:\n    input_name = model.input_names[0]\n    output_name = model.output_names[0]\n\n  np.random.seed(_RANDOM_SEED)\n  (x_train, y_train), (x_test, y_test) = get_test_data(\n      train_samples=_TRAIN_SIZE,\n      test_samples=50,\n      input_shape=_INPUT_SIZE,\n      num_classes=_NUM_CLASS)\n  y_train = tf_keras.utils.to_categorical(y_train)\n  y_test = tf_keras.utils.to_categorical(y_test)\n\n  train_input_fn = gen_input_fn(\n      x=randomize_io_type(x_train, input_name),\n      y=randomize_io_type(y_train, output_name),\n      shuffle=False,\n      num_epochs=None,\n      batch_size=16)\n\n  evaluate_input_fn = gen_input_fn(\n      x=randomize_io_type(x_test, input_name),\n      y=randomize_io_type(y_test, output_name),\n 
     num_epochs=1,\n      shuffle=False)\n\n  predict_input_fn = gen_input_fn(\n      x=randomize_io_type(x_test, input_name), num_epochs=1, shuffle=False)\n\n  inference_input_fn = evaluate_input_fn if is_evaluate else predict_input_fn\n\n  return model, (x_train, y_train), (x_test,\n                                     y_test), train_input_fn, inference_input_fn\n\n\ndef randomize_io_type(array, name):\n  switch = np.random.random()\n  if switch > 0.5:\n    return array\n  else:\n    return {name: array}\n\n\ndef multi_inputs_multi_outputs_model():\n  input_a = tf_keras.layers.Input(shape=(16,), name='input_a')\n  input_b = tf_keras.layers.Input(shape=(16,), name='input_b')\n  input_m = tf_keras.layers.Input(shape=(8,), dtype='string', name='input_m')\n  dense = tf_keras.layers.Dense(8, name='dense_1')\n\n  interm_a = dense(input_a)\n  # Read m\n  interm_m = tf_keras.layers.Lambda(gen_parsing_ops.string_to_number)(input_m)\n  interm_s = tf_keras.layers.Lambda(lambda k: k[0] * k[1])([interm_m, interm_a])\n  interm_b = dense(input_b)\n  merged = tf_keras.layers.concatenate([interm_s, interm_b], name='merge')\n  output_c = tf_keras.layers.Dense(3, activation='softmax', name='dense_2')(\n      merged)\n  output_d = tf_keras.layers.Dense(2, activation='softmax', name='dense_3')(\n      merged)\n  model = tf_keras.models.Model(\n      inputs=[input_a, input_b, input_m], outputs=[output_c, output_d])\n  model.compile(\n      loss='categorical_crossentropy',\n      optimizer='rmsprop',\n      metrics={\n          'dense_2': 'categorical_accuracy',\n          'dense_3': 'categorical_accuracy'\n      })\n  return model\n\n\nclass MyHook(tf.compat.v1.train.SessionRunHook):\n\n  def begin(self):\n    _ = tf.compat.v1.get_variable('temp', [1])\n\n\nclass TestKerasEstimator(tf.test.TestCase, parameterized.TestCase):\n\n  def setUp(self):\n    self._base_dir = os.path.join(self.get_temp_dir(), 'keras_estimator_test')\n    tf.compat.v1.gfile.MakeDirs(self._base_dir)\n    
self._config = run_config_lib.RunConfig(\n        tf_random_seed=_RANDOM_SEED, model_dir=self._base_dir)\n    super(TestKerasEstimator, self).setUp()\n\n  def tearDown(self):\n    # Make sure nothing is stuck in limbo.\n    tf.compat.v1.summary.FileWriterCache.clear()\n    if os.path.isdir(self._base_dir):\n      tf.compat.v1.gfile.DeleteRecursively(self._base_dir)\n    tf_keras.backend.clear_session()\n    super(TestKerasEstimator, self).tearDown()\n\n  @parameterized.named_parameters(\n      dict(\n          testcase_name='functional',\n          model_type='functional',\n          checkpoint_format='saver'),\n      dict(\n          testcase_name='sequential',\n          model_type='sequential',\n          checkpoint_format='saver'),\n      dict(\n          testcase_name='subclass',\n          model_type='subclass',\n          optimizer='tf_rmsprop',\n          checkpoint_format='saver'),\n      dict(\n          testcase_name='functional_object_ckpt',\n          model_type='functional',\n          checkpoint_format='checkpoint'),\n      dict(\n          testcase_name='sequential_object_ckpt_w_fit',\n          model_type='sequential',\n          checkpoint_format='checkpoint',\n          fit_before_export=True,\n          optimizer='tf_rmsprop'),\n      dict(\n          testcase_name='functional_w_fit',\n          model_type='functional',\n          fit_before_export=True,\n          optimizer='tf_rmsprop',\n          checkpoint_format='saver'),\n      dict(\n          testcase_name='subclass_w_fit',\n          model_type='subclass',\n          fit_before_export=True,\n          optimizer='tf_rmsprop',\n          checkpoint_format='saver'),\n      # b/109935364\n      dict(\n          testcase_name='hooks',\n          model_type='subclass',\n          hook=MyHook,\n          optimizer='tf_rmsprop',\n          checkpoint_format='saver'),\n      dict(\n          testcase_name='hooks_and_fit',\n          model_type='subclass',\n          hook=MyHook,\n          
fit_before_export=True,\n          optimizer='tf_rmsprop',\n          checkpoint_format='saver'),\n      dict(\n          testcase_name='tf_optimizer',\n          model_type='subclass',\n          hook=MyHook,\n          optimizer='tf_rmsprop',\n          fit_before_export=True,\n          checkpoint_format='saver'))\n  def test_train_keras_estimator(self,\n                                 model_type,\n                                 checkpoint_format=None,\n                                 fit_before_export=False,\n                                 optimizer='rmsprop',\n                                 hook=None):\n    hooks = [hook()] if hook else None\n    tf_optimizer = False\n    if optimizer == 'tf_rmsprop':\n      tf_optimizer = True\n      optimizer = tf.compat.v1.train.RMSPropOptimizer(1e-3)\n\n    keras_model, (x_train, y_train), (_, _), train_input_fn, eval_input_fn = (\n        get_resource_for_simple_model(model_type=model_type, is_evaluate=True))\n    keras_model.compile(\n        optimizer=optimizer,\n        loss='categorical_crossentropy',\n        metrics=['accuracy'])\n    if fit_before_export:\n      keras_model.fit(x_train, y_train, epochs=1)\n\n    est_keras = keras_lib.model_to_estimator(\n        keras_model=keras_model,\n        config=self._config,\n        checkpoint_format=checkpoint_format)\n\n    est_keras.train(\n        input_fn=train_input_fn, steps=_TRAIN_SIZE / 16, hooks=hooks)\n    before_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)\n    est_keras.train(\n        input_fn=train_input_fn, steps=_TRAIN_SIZE / 16, hooks=hooks)\n    after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)\n    self.assertLess(after_eval_results['loss'], before_eval_results['loss'])\n\n    if checkpoint_format == 'object' and tf_optimizer:\n      latest_checkpoint = tf.train.latest_checkpoint(est_keras.model_dir)\n      keras_model.load_weights(latest_checkpoint)\n\n  def test_train_with_dense_features(self):\n    
feature_dict = {\n        'sex': np.int64([1, 1, 1, 1, 0]),\n        'cp': np.int64([0, 3, 3, 2, 1]),\n        'slope': np.int64([3, 2, 0, 3, 1]),\n    }\n    label = np.int64([0, 1, 0, 0, 0])\n    train_input_fn = numpy_io.numpy_input_fn(\n        x=feature_dict, y=label, num_epochs=1, shuffle=False)\n    feature_columns = list()\n    input_features = dict()\n    for feature_name, data_array in feature_dict.items():\n      feature_columns.append(\n          tf.feature_column.indicator_column(\n              tf.feature_column.categorical_column_with_identity(\n                  key=feature_name,\n                  num_buckets=np.size(np.unique(data_array)))))\n      input_features[feature_name] = tf_keras.layers.Input(\n          name=feature_name,\n          shape=(np.size(np.unique(data_array)),),\n          dtype=tf.dtypes.int64)\n\n    x = tf_keras_v1.layers.DenseFeatures(feature_columns)(input_features)\n    x = tf_keras.layers.Dense(16, activation='relu')(x)\n    logits = tf_keras.layers.Dense(1, activation='linear')(x)\n    model = tf_keras.models.Model(inputs=input_features, outputs=logits)\n\n    model.compile(\n        optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])\n    estimator_model = keras_lib.model_to_estimator(keras_model=model)\n    estimator_model.train(input_fn=train_input_fn, steps=5)\n\n  # TODO(b/139845232): Enable after TF2 nightly's start.\n  def DISABLED_test_train_with_dense_features_embedding(self):\n    feature_dict = {\n        'sex': np.int64([1, 1, 1, 1, 0]),\n        'cp': np.int64([0, 3, 3, 2, 1]),\n        'slope': np.int64([3, 2, 0, 3, 1]),\n    }\n    label = np.int64([0, 1, 0, 0, 0])\n    train_input_fn = numpy_io.numpy_input_fn(\n        x=feature_dict, y=label, num_epochs=1, shuffle=False)\n    feature_columns = list()\n    input_features = dict()\n    for feature_name, data_array in feature_dict.items():\n      feature_columns.append(\n          tf.feature_column.embedding_column(\n              
tf.feature_column.categorical_column_with_identity(\n                  key=feature_name, num_buckets=np.size(np.unique(data_array))),\n              dimension=3))\n      input_features[feature_name] = tf_keras.layers.Input(\n          name=feature_name,\n          shape=(np.size(np.unique(data_array)),),\n          dtype=tf.dtypes.int64)\n\n    df = tf_keras_v1.layers.DenseFeatures(feature_columns)\n    x = df(input_features)\n    x = tf_keras.layers.Dense(16, activation='relu')(x)\n    logits = tf_keras.layers.Dense(1, activation='linear')(x)\n    model = tf_keras.models.Model(inputs=input_features, outputs=logits)\n\n    model.compile(\n        optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])\n    estimator_model = keras_lib.model_to_estimator(keras_model=model)\n    estimator_model.train(input_fn=train_input_fn, steps=5)\n    # We assert that we find the embedding_weights variables in the dependencies\n    # for the DenseFeatures layer.\n    dependency_names = list(df._trackable_children())\n    self.assertNotIn('embedding_weights', dependency_names)\n    self.assertIn('cp_embedding/embedding_weights', dependency_names)\n    self.assertIn('sex_embedding/embedding_weights', dependency_names)\n    self.assertIn('slope_embedding/embedding_weights', dependency_names)\n\n  # TODO(b/139845232): Enable after TF2 nightly's start.\n  def DISABLED_test_train_with_dense_features_v2(self):\n    feature_dict = {\n        'sex': np.int64([1, 1, 1, 1, 0]),\n        'cp': np.int64([0, 3, 3, 2, 1]),\n        'slope': np.int64([3, 2, 0, 3, 1]),\n    }\n    label = np.int64([0, 1, 0, 0, 0])\n    train_input_fn = numpy_io.numpy_input_fn(\n        x=feature_dict, y=label, num_epochs=1, shuffle=False)\n    feature_columns = list()\n    input_features = dict()\n    for feature_name, data_array in feature_dict.items():\n      feature_columns.append(\n          tf.feature_column.embedding_column(\n              tf.feature_column.categorical_column_with_identity(\n  
                key=feature_name, num_buckets=np.size(np.unique(data_array))),\n              dimension=3))\n      input_features[feature_name] = tf_keras.layers.Input(\n          name=feature_name,\n          shape=(np.size(np.unique(data_array)),),\n          dtype=tf.dtypes.int64)\n\n    df = tf_keras_v2.layers.DenseFeatures(feature_columns)\n    x = df(input_features)\n    x = tf_keras.layers.Dense(16, activation='relu')(x)\n    logits = tf_keras.layers.Dense(1, activation='linear')(x)\n    model = tf_keras.models.Model(inputs=input_features, outputs=logits)\n\n    model.compile(\n        optimizer='rmsprop', loss='binary_crossentropy', metrics=['accuracy'])\n    estimator_model = keras_lib.model_to_estimator(keras_model=model)\n    estimator_model.train(input_fn=train_input_fn, steps=5)\n    # We assert that we find the embedding_weights variables in the dependencies\n    # for the DenseFeatures layer.\n    dependency_names = list(df._trackable_children())\n    self.assertNotIn('embedding_weights', dependency_names)\n    self.assertIn('cp_embedding/embedding_weights', dependency_names)\n    self.assertIn('sex_embedding/embedding_weights', dependency_names)\n    self.assertIn('slope_embedding/embedding_weights', dependency_names)\n\n  def test_evaluate(self):\n    keras_model, (x_train, y_train), (\n        x_test, y_test), _, eval_input_fn = get_resource_for_simple_model(\n            model_type='functional', is_evaluate=True)\n\n    metrics = [\n        'binary_accuracy', 'binary_crossentropy', 'categorical_accuracy',\n        'categorical_crossentropy', 'cosine_proximity', 'hinge',\n        'kullback_leibler_divergence', 'mean_absolute_error',\n        'mean_absolute_percentage_error', 'mean_squared_error',\n        'mean_squared_logarithmic_error', 'poisson', 'squared_hinge',\n        'top_k_categorical_accuracy'\n    ]\n    keras_model.compile(\n        loss='categorical_crossentropy', optimizer='adam', metrics=metrics)\n    keras_model.fit(x_train, 
y_train, epochs=1)\n    keras_eval = keras_model.evaluate(x_test, y_test, batch_size=32)\n\n    keras_est = keras_lib.model_to_estimator(\n        keras_model=keras_model, config=self._config)\n    est_eval = keras_est.evaluate(input_fn=eval_input_fn)\n\n    metrics = ['loss'] + metrics\n\n    # Check loss and all metrics match between keras and estimator.\n    def shift(val):\n      if val == 0:\n        return 0\n      else:\n        return val / 10**int(math.log10(abs(val)))\n\n    for i, metric_name in enumerate(metrics):\n      if i == 0:\n        continue  # TODO(b/148461691): Investigate 1% diff in loss.\n      self.assertAlmostEqual(\n          shift(keras_eval[i]),\n          shift(est_eval[metric_name]),\n          places=4,\n          msg='%s mismatch, keras model: %s, estimator: %s' %\n          (metric_name, keras_eval[i], est_eval[metric_name]))\n\n  def test_evaluate_multi_io_model(self):\n    input_a = tf_keras.layers.Input(shape=(16,), name='input_a')\n    input_b = tf_keras.layers.Input(shape=(16,), name='input_b')\n    dense = tf_keras.layers.Dense(8, name='dense_1')\n    interm_a = dense(input_a)\n    interm_b = dense(input_b)\n    merged = tf_keras.layers.concatenate([interm_a, interm_b], name='merge')\n    output_a = tf_keras.layers.Dense(\n        3, activation='softmax', name='dense_2')(\n            merged)\n    output_b = tf_keras.layers.Dense(\n        2, activation='softmax', name='dense_3')(\n            merged)\n    keras_model = tf_keras.models.Model(\n        inputs=[input_a, input_b], outputs=[output_a, output_b])\n    keras_model.compile(\n        loss='categorical_crossentropy',\n        optimizer='rmsprop',\n        metrics={\n            'dense_2': 'categorical_accuracy',\n            'dense_3': 'categorical_accuracy'\n        })\n\n    np.random.seed(_RANDOM_SEED)\n    (x_train_1, y_train_1), (x_test_1, y_test_1) = get_test_data(\n        train_samples=_TRAIN_SIZE,\n        test_samples=50,\n        input_shape=(16,),\n        
num_classes=3)\n    (x_train_2, y_train_2), (x_test_2, y_test_2) = get_test_data(\n        train_samples=_TRAIN_SIZE,\n        test_samples=50,\n        input_shape=(16,),\n        num_classes=2)\n    y_train_1 = tf_keras.utils.to_categorical(y_train_1)\n    y_test_1 = tf_keras.utils.to_categorical(y_test_1)\n    y_train_2 = tf_keras.utils.to_categorical(y_train_2)\n    y_test_2 = tf_keras.utils.to_categorical(y_test_2)\n\n    keras_model.fit((x_train_1, x_train_2), (y_train_1, y_train_2), epochs=1)\n    keras_eval = keras_model.evaluate((x_test_1, x_test_2),\n                                      (y_test_1, y_test_2),\n                                      batch_size=32)\n\n    def input_fn():\n      ds = tf.compat.v1.data.Dataset.from_tensor_slices(\n          ((x_test_1, x_test_2), (y_test_1, y_test_2)))\n      return ds.batch(128)\n\n    keras_est = keras_lib.model_to_estimator(\n        keras_model=keras_model, config=self._config)\n    est_eval = keras_est.evaluate(input_fn=input_fn)\n\n    def verify_correctness(metric_names):\n      for i, metric_name in enumerate(metric_names):\n        if i < 3:  # TODO(b/148461691): Investigate 1% diff in loss.\n          continue\n        self.assertAlmostEqual(\n            keras_eval[i],\n            est_eval[metric_name],\n            places=4,\n            msg='%s mismatch, keras model: %s, estimator: %s' %\n            (metric_name, keras_eval[i], est_eval[metric_name]))\n\n    verify_correctness([\n        'loss', 'dense_2_loss', 'dense_3_loss', 'dense_2_categorical_accuracy',\n        'dense_3_categorical_accuracy'\n    ])\n\n    metric_names_map = {\n        'dense_2_categorical_accuracy': 'acc_1',\n        'dense_3_categorical_accuracy': 'acc_2',\n    }\n    keras_est = keras_lib.model_to_estimator(\n        keras_model=keras_model,\n        config=self._config,\n        metric_names_map=metric_names_map)\n    est_eval = keras_est.evaluate(input_fn=input_fn)\n    verify_correctness(\n        ['loss', 
'dense_2_loss', 'dense_3_loss', 'acc_1', 'acc_2'])\n\n  def test_invalid_metric_names_map(self):\n    keras_model, (_, _), (_,\n                          _), _, eval_input_fn = get_resource_for_simple_model(\n                              model_type='functional', is_evaluate=True)\n    keras_model.compile(\n        loss='categorical_crossentropy',\n        optimizer='adam',\n        metrics=['binary_accuracy'])\n\n    keras_est = keras_lib.model_to_estimator(\n        keras_model=keras_model,\n        config=self._config,\n        metric_names_map={'binary_acc': ''})\n    with self.assertRaisesRegexp(ValueError,\n                                 r'Invalid `metric_names_map`.*do not exist'):\n      keras_est.evaluate(input_fn=eval_input_fn)\n    keras_est = keras_lib.model_to_estimator(\n        keras_model=keras_model,\n        config=self._config,\n        metric_names_map={\n            'binary_accuracy': 'acc',\n            'abcde': ''\n        })\n    with self.assertRaisesRegexp(\n        ValueError, r'Invalid `metric_names_map`.*unexpected keys'):\n      keras_est.evaluate(input_fn=eval_input_fn)\n\n  def test_predict(self):\n    # Check that predict on a pretrained model yield the same result.\n    keras_model, (x_train, y_train), (\n        x_test, _), _, pred_input_fn = get_resource_for_simple_model(\n            model_type='sequential', is_evaluate=False)\n\n    keras_model.compile(\n        loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n\n  def test_multi_inputs_multi_outputs_with_input_fn_as_dict(self):\n    train_data, test_data = get_multi_inputs_multi_outputs_data()\n\n    def train_input_fn():\n      input_dict = {\n          'input_a': train_data['input_a'],\n          'input_b': train_data['input_b'],\n          'input_m': train_data['input_m'].astype(str)\n      }\n      output_dict = {\n          'dense_2': train_data['output_c'],\n          'dense_3': train_data['output_d']\n      }\n      return input_dict, 
output_dict\n\n    def eval_input_fn():\n      input_dict = {\n          'input_a': test_data['input_a'],\n          'input_b': test_data['input_b'],\n          'input_m': test_data['input_m'].astype(str)\n      }\n      output_dict = {\n          'dense_2': test_data['output_c'],\n          'dense_3': test_data['output_d']\n      }\n      return input_dict, output_dict\n\n    def pred_input_fn():\n      input_dict = {\n          'input_a': test_data['input_a'],\n          'input_b': test_data['input_b'],\n          'input_m': test_data['input_m'].astype(str)\n      }\n      return input_dict\n\n    self.do_test_multi_inputs_multi_outputs_with_input_fn(\n        train_input_fn, eval_input_fn, pred_input_fn)\n\n  def test_multi_inputs_multi_outputs_with_input_fn_as_list(self):\n    train_data, test_data = get_multi_inputs_multi_outputs_data()\n\n    def train_input_fn():\n      input_list = [\n          train_data['input_a'], train_data['input_b'],\n          train_data['input_m'].astype(str)\n      ]\n      output_list = [train_data['output_c'], train_data['output_d']]\n      return input_list, output_list\n\n    def eval_input_fn():\n      input_list = [\n          test_data['input_a'], test_data['input_b'],\n          test_data['input_m'].astype(str)\n      ]\n      output_list = [test_data['output_c'], test_data['output_d']]\n      return input_list, output_list\n\n    def pred_input_fn():\n      input_list = [\n          test_data['input_a'], test_data['input_b'],\n          test_data['input_m'].astype(str)\n      ]\n      return input_list\n\n    self.do_test_multi_inputs_multi_outputs_with_input_fn(\n        train_input_fn, eval_input_fn, pred_input_fn)\n\n  def do_test_multi_inputs_multi_outputs_with_input_fn(self, train_input_fn,\n                                                       eval_input_fn,\n                                                       pred_input_fn):\n    model = multi_inputs_multi_outputs_model()\n    est_keras = 
keras_lib.model_to_estimator(\n        keras_model=model, config=self._config)\n    baseline_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)\n    est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)\n    eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)\n    self.assertLess(eval_results['loss'], baseline_eval_results['loss'])\n    est_keras.predict(input_fn=pred_input_fn)\n\n  def test_init_from_file(self):\n    if h5py is None:\n      return  # Skip test if models cannot be saved.\n\n    keras_model, (x_train, y_train), (\n        x_test, _), _, pred_input_fn = get_resource_for_simple_model(\n            model_type='functional', is_evaluate=False)\n\n    keras_model.compile(\n        loss='categorical_crossentropy',\n        optimizer='rmsprop',\n        metrics=['categorical_accuracy'])\n    keras_model.fit(x_train, y_train, epochs=1)\n    keras_pred = [np.argmax(y) for y in keras_model.predict(x_test)]\n    fname = os.path.join(self._base_dir, 'keras_model.h5')\n    tf_keras.models.save_model(keras_model, fname)\n\n    keras_est = keras_lib.model_to_estimator(\n        keras_model_path=fname, config=self._config)\n    est_pred = [\n        np.argmax(y[keras_model.output_names[0]])\n        for y in keras_est.predict(input_fn=pred_input_fn)\n    ]\n    self.assertAllEqual(est_pred, keras_pred)\n\n  def test_keras_model_init_error(self):\n    with self.assertRaisesRegexp(ValueError, 'Either'):\n      keras_lib.model_to_estimator()\n\n    keras_model = simple_sequential_model()\n    with self.assertRaisesRegexp(ValueError, 'not both'):\n      keras_lib.model_to_estimator(\n          keras_model=keras_model,\n          keras_model_path=tempfile.mkdtemp(dir=self._base_dir))\n\n    keras_model = simple_sequential_model()\n    with self.assertRaisesRegexp(ValueError, 'compiled'):\n      keras_lib.model_to_estimator(keras_model=keras_model)\n\n  def test_invalid_ionames_error(self):\n    (x_train, y_train), (_, _) = 
get_test_data(\n        train_samples=_TRAIN_SIZE,\n        test_samples=100,\n        input_shape=(10,),\n        num_classes=2)\n    y_train = tf_keras.utils.to_categorical(y_train)\n\n    def invald_input_name_input_fn():\n      input_dict = {'invalid_input_name': x_train}\n      return input_dict, y_train\n\n    def invald_output_name_input_fn():\n      input_dict = {'input_layer': x_train}\n      output_dict = {'invalid_output_name': y_train}\n      return input_dict, output_dict\n\n    model = simple_functional_model()\n    model.compile(\n        loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])\n    est_keras = keras_lib.model_to_estimator(\n        keras_model=model, config=self._config)\n\n    regexp_pattern = r'{} keys:(\\s|.)*{}(\\s|.)*Missed keys:(\\s|.)*{}'\n\n    with self.assertRaisesRegexp(\n        keras_lib.FormattedKeyError,\n        regexp_pattern.format('features', 'invalid_input_name', 'input_layer')):\n      est_keras.train(input_fn=invald_input_name_input_fn, steps=100)\n\n    with self.assertRaisesRegexp(\n        keras_lib.FormattedKeyError,\n        regexp_pattern.format('labels', 'invalid_output_name', 'dense_1')):\n      est_keras.train(input_fn=invald_output_name_input_fn, steps=100)\n\n  def test_custom_objects(self):\n\n    def custom_relu(x):\n      return tf_keras.backend.relu(x, max_value=6)\n\n    keras_model = simple_functional_model(activation=custom_relu)\n    keras_model.compile(loss='categorical_crossentropy', optimizer='adam')\n    custom_objects = {'custom_relu': custom_relu}\n\n    (x_train, y_train), _ = get_test_data(\n        train_samples=_TRAIN_SIZE,\n        test_samples=50,\n        input_shape=(10,),\n        num_classes=2)\n    y_train = tf_keras.utils.to_categorical(y_train, 2)\n    input_name = keras_model.input_names[0]\n    output_name = keras_model.output_names[0]\n    train_input_fn = gen_input_fn(\n        x=randomize_io_type(x_train, input_name),\n        y=randomize_io_type(y_train, 
output_name),\n        shuffle=False,\n        num_epochs=None,\n        batch_size=16)\n    with self.assertRaisesRegex(Exception, 'custom_relu'):\n      # Could be either a TypeError or ValueError\n      est = keras_lib.model_to_estimator(\n          keras_model=keras_model,\n          model_dir=tempfile.mkdtemp(dir=self._base_dir))\n      est.train(input_fn=train_input_fn, steps=1)\n\n    est = keras_lib.model_to_estimator(\n        keras_model=keras_model,\n        model_dir=tempfile.mkdtemp(dir=self._base_dir),\n        custom_objects=custom_objects)\n    est.train(input_fn=train_input_fn, steps=1)\n\n  def test_tf_config(self):\n    keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()\n    keras_model.compile(\n        loss='categorical_crossentropy',\n        optimizer='rmsprop',\n        metrics=['mse', tf_keras.metrics.CategoricalAccuracy()])\n\n    tf_config = json.dumps({\n        'cluster': {\n            run_config_lib.TaskType.PS: ['localhost:1234'],\n            run_config_lib.TaskType.WORKER: ['localhost:1236'],\n            run_config_lib.TaskType.MASTER: ['localhost:1238']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.MASTER,\n            'index': 0\n        }\n    })\n    with tf.compat.v1.test.mock.patch.dict('os.environ',\n                                           {'TF_CONFIG': tf_config}):\n      keras_lib.model_to_estimator(\n          keras_model=keras_model,\n          model_dir=tempfile.mkdtemp(dir=self._base_dir))\n\n  def test_gpu_config(self):\n    with tf.Graph().as_default():\n      keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()\n      keras_model.compile(\n          loss='categorical_crossentropy',\n          optimizer='rmsprop',\n          metrics=['mse', tf_keras.metrics.CategoricalAccuracy()])\n\n      gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.3)\n      sess_config = tf.compat.v1.ConfigProto(gpu_options=gpu_options)\n      
self._config._session_config = sess_config\n      keras_lib.model_to_estimator(keras_model=keras_model, config=self._config)\n      self.assertEqual(\n          tf_keras_v1.backend.get_session(\n          )._config.gpu_options.per_process_gpu_memory_fraction,\n          gpu_options.per_process_gpu_memory_fraction)\n\n  def test_with_empty_config(self):\n    keras_model, _, _, _, _ = get_resource_for_simple_model(\n        model_type='sequential', is_evaluate=True)\n    keras_model.compile(\n        loss='categorical_crossentropy',\n        optimizer='rmsprop',\n        metrics=['mse', tf_keras.metrics.CategoricalAccuracy()])\n\n    est_keras = keras_lib.model_to_estimator(\n        keras_model=keras_model,\n        model_dir=self._base_dir,\n        config=run_config_lib.RunConfig())\n    self.assertEqual(run_config_lib.get_default_session_config(),\n                     est_keras._session_config)\n    self.assertEqual(est_keras._session_config,\n                     est_keras._config.session_config)\n    self.assertEqual(self._base_dir, est_keras._config.model_dir)\n    self.assertEqual(self._base_dir, est_keras._model_dir)\n\n    est_keras = keras_lib.model_to_estimator(\n        keras_model=keras_model, model_dir=self._base_dir, config=None)\n    self.assertEqual(run_config_lib.get_default_session_config(),\n                     est_keras._session_config)\n    self.assertEqual(est_keras._session_config,\n                     est_keras._config.session_config)\n    self.assertEqual(self._base_dir, est_keras._config.model_dir)\n    self.assertEqual(self._base_dir, est_keras._model_dir)\n\n  def test_with_empty_config_and_empty_model_dir(self):\n    keras_model, _, _, _, _ = get_resource_for_simple_model(\n        model_type='sequential', is_evaluate=True)\n    keras_model.compile(\n        loss='categorical_crossentropy',\n        optimizer='rmsprop',\n        metrics=['mse', tf_keras.metrics.CategoricalAccuracy()])\n\n    with 
tf.compat.v1.test.mock.patch.object(\n        tempfile, 'mkdtemp', return_value=_TMP_DIR):\n      est_keras = keras_lib.model_to_estimator(\n          keras_model=keras_model, config=run_config_lib.RunConfig())\n      self.assertEqual(est_keras._model_dir, _TMP_DIR)\n\n  def test_with_conflicting_model_dir_and_config(self):\n    keras_model, _, _, _, _ = get_resource_for_simple_model(\n        model_type='sequential', is_evaluate=True)\n    keras_model.compile(\n        loss='categorical_crossentropy',\n        optimizer='rmsprop',\n        metrics=['mse', tf_keras.metrics.CategoricalAccuracy()])\n\n    with self.assertRaisesRegexp(\n        ValueError, '`model_dir` are set both in '\n        'constructor and `RunConfig`'):\n      keras_lib.model_to_estimator(\n          keras_model=keras_model,\n          model_dir=self._base_dir,\n          config=run_config_lib.RunConfig(model_dir=_TMP_DIR))\n\n  def test_pretrained_weights(self):\n    keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()\n    keras_model.compile(\n        loss='categorical_crossentropy',\n        optimizer=tf.compat.v1.train.RMSPropOptimizer(1e-3),\n        metrics=['mse', tf_keras.metrics.CategoricalAccuracy()])\n    keras_model.train_on_batch(\n        np.random.random((10,) + _INPUT_SIZE), np.random.random(\n            (10, _NUM_CLASS)))\n    weights = keras_model.get_weights()\n    keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model()\n    keras_model.set_weights(weights)\n\n    keras_model.compile(\n        loss='categorical_crossentropy',\n        optimizer='sgd',\n        metrics=['mse', tf_keras.metrics.CategoricalAccuracy()])\n    keras_lib.model_to_estimator(keras_model=keras_model, config=self._config)\n\n  def assert_increasing_global_step(self, optimizer):\n    keras_model, _, _, train_input_fn, _ = get_resource_for_simple_model(\n        model_type='sequential', is_evaluate=True)\n    keras_model.compile(\n        loss='categorical_crossentropy',\n     
   optimizer=optimizer,\n        metrics=['mse', tf_keras.metrics.CategoricalAccuracy()])\n    with self.cached_session() as sess:\n      keras_model_fn = keras_lib._create_keras_model_fn(keras_model)\n      global_step = tf.compat.v1.train.create_global_step()\n      features, labels = train_input_fn().make_one_shot_iterator().get_next()\n      spec = keras_model_fn(features, labels, mode=ModeKeys.TRAIN)\n\n      sess.run(tf.compat.v1.initializers.global_variables())\n      sess.run(tf.compat.v1.initializers.local_variables())\n\n      self.assertEqual(global_step.eval(), 0)  # Sanity check\n      sess.run(spec.train_op)\n      self.assertEqual(global_step.eval(), 1)\n\n  @test_util.run_v1_only('training_util.create_global_step is v1 only.')\n  def test_model_fn_increments_global_step_tf_optimizer(self):\n    self.assert_increasing_global_step(\n        tf.compat.v1.train.RMSPropOptimizer(1e-3))\n\n  @test_util.run_v1_only('training_util.create_global_step is v1 only.')\n  def test_model_fn_increments_global_step_keras_optimizer(self):\n    self.assert_increasing_global_step('rmsprop')\n\n  @parameterized.named_parameters(\n      dict(testcase_name='object_ckpt', checkpoint_format='checkpoint'),\n      dict(testcase_name='name_ckpt', checkpoint_format='saver'))\n  def test_export_keras_estimator(self, checkpoint_format):\n    keras_model, (x_train, y_train), (\n        _, _), train_input_fn, _ = get_resource_for_simple_model(\n            model_type='sequential', is_evaluate=False)\n\n    keras_model.compile(\n        loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])\n    keras_model.fit(x_train, y_train, epochs=1)\n    bias_value = tf_keras.backend.get_value(keras_model.layers[0].bias)\n\n    est_keras = keras_lib.model_to_estimator(\n        keras_model=keras_model,\n        model_dir=tempfile.mkdtemp(dir=self._base_dir),\n        checkpoint_format=checkpoint_format)\n\n    def serving_input_receiver_fn():\n      feature_spec = {\n         
 'dense_input': tf.io.FixedLenFeature([1], dtype=tf.dtypes.float32)\n      }\n      return export_lib.build_parsing_serving_input_receiver_fn(feature_spec)\n\n    # Try immediately exporting, testing that (1) exported values are the same,\n    # and (2) estimator can be exported without saving a checkpoint into the\n    # model directory.\n    saved_model_dir = est_keras.export_saved_model(\n        tempfile.mkdtemp(dir=self._base_dir), serving_input_receiver_fn())\n    variables_path = path_helpers.get_variables_path(saved_model_dir)\n\n    variable_name = 'dense/bias'\n    if checkpoint_format == 'checkpoint':\n      names_to_keys = saver_lib.object_graph_key_mapping(variables_path)\n      variable_name = names_to_keys[variable_name]\n\n    self.assertAllClose(bias_value,\n                        tf.train.load_variable(variables_path, variable_name))\n\n    # Export the estimator after training a bit.\n    est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)\n    saved_model_dir = est_keras.export_saved_model(\n        tempfile.mkdtemp(dir=self._base_dir), serving_input_receiver_fn())\n    variables_path = path_helpers.get_variables_path(saved_model_dir)\n    self.assertNotAllClose(\n        bias_value, tf.train.load_variable(variables_path, variable_name))\n\n  @parameterized.named_parameters(\n      dict(testcase_name='object_ckpt', checkpoint_format='checkpoint'),\n      dict(testcase_name='name_ckpt', checkpoint_format='saver'))\n  def test_export_keras_estimator_custom_signatures(self, checkpoint_format):\n    inputs_a = np.random.random((320, 1))\n    inputs_b = np.random.random((320, 1))\n    outputs_c = np.random.random((320, 1))\n    outputs_d = np.random.random((320, 1))\n\n    dataset = tf.data.Dataset.from_tensor_slices((\n        {'a': inputs_a, 'b': inputs_b},\n        {'c': outputs_c, 'd': outputs_d})).batch(32)\n    keras_inputs_a = tf_keras.Input(shape=(1,), dtype=tf.float32, name='a')\n    keras_inputs_b = tf_keras.Input(shape=(1,), 
dtype=tf.float32, name='b')\n    keras_outputs_c = tf_keras.layers.Dense(units=1, name='c')(keras_inputs_a)\n    keras_outputs_d = tf_keras.layers.Dense(\n        units=1, name='d', activation='sigmoid')(keras_inputs_b)\n    keras_model = tf_keras.Model(\n        inputs={'a': keras_inputs_a, 'b': keras_inputs_b},\n        outputs={'c': keras_outputs_c, 'd': keras_outputs_d})\n    keras_model.compile('sgd', {'c': 'mse', 'd': 'binary_crossentropy'}, [])\n    keras_model.fit(dataset)\n\n    est_keras = keras_lib.model_to_estimator(\n        keras_model=keras_model,\n        model_dir=tempfile.mkdtemp(dir=self._base_dir),\n        checkpoint_format=checkpoint_format,\n        export_outputs={'c': export_output.RegressionOutput,\n                        'd': export_output.ClassificationOutput})\n\n    def serving_input_receiver_fn():\n      feature_spec = {\n          'a': tf.io.FixedLenFeature([1], dtype=tf.dtypes.float32),\n          'b': tf.io.FixedLenFeature([1], dtype=tf.dtypes.float32),\n      }\n      return export_lib.build_parsing_serving_input_receiver_fn(feature_spec)\n\n    # Try immediately exporting, testing exported signatures\n    saved_model_dir = est_keras.export_saved_model(\n        tempfile.mkdtemp(dir=self._base_dir), serving_input_receiver_fn())\n    imported_est = tf.saved_model.load(saved_model_dir)\n    imported_signatures = imported_est.signatures\n    assert 'c' in imported_signatures\n    assert 'd' in imported_signatures\n    assert 'serving_default' in imported_signatures\n\n  @parameterized.named_parameters(\n      dict(testcase_name='object_ckpt', checkpoint_format='checkpoint'),\n      dict(testcase_name='name_ckpt', checkpoint_format='saver'))\n  def test_export_keras_estimator_unknown_signatures(self, checkpoint_format):\n    inputs_a = np.random.random((320, 1))\n    inputs_b = np.random.random((320, 1))\n    outputs_c = np.random.random((320, 1))\n    outputs_d = np.random.random((320, 1))\n\n    dataset = 
tf.data.Dataset.from_tensor_slices((\n        {'a': inputs_a, 'b': inputs_b},\n        {'c': outputs_c, 'd': outputs_d})).batch(32)\n    keras_inputs_a = tf_keras.Input(shape=(1,), dtype=tf.float32, name='a')\n    keras_inputs_b = tf_keras.Input(shape=(1,), dtype=tf.float32, name='b')\n    keras_outputs_c = tf_keras.layers.Dense(units=1, name='c')(keras_inputs_a)\n    keras_outputs_d = tf_keras.layers.Dense(\n        units=1, name='d', activation='sigmoid')(keras_inputs_b)\n    keras_model = tf_keras.Model(\n        inputs={'a': keras_inputs_a, 'b': keras_inputs_b},\n        outputs={'c': keras_outputs_c, 'd': keras_outputs_d})\n    keras_model.compile('sgd', {'c': 'mse', 'd': 'binary_crossentropy'}, [])\n    keras_model.fit(dataset)\n\n    with self.assertRaisesRegex(\n        keras_lib.FormattedKeyError,\n        r'Missed keys'):\n      est_keras = keras_lib.model_to_estimator(\n          keras_model=keras_model,\n          model_dir=tempfile.mkdtemp(dir=self._base_dir),\n          checkpoint_format=checkpoint_format,\n          export_outputs={'c': export_output.RegressionOutput,\n                          'p': export_output.ClassificationOutput})\n      def serving_input_receiver_fn():\n        feature_spec = {\n            'a': tf.io.FixedLenFeature([1], dtype=tf.dtypes.float32),\n            'b': tf.io.FixedLenFeature([1], dtype=tf.dtypes.float32),\n        }\n        return export_lib.build_parsing_serving_input_receiver_fn(feature_spec)\n\n      # Try immediately exporting, testing exported signatures\n      _ = est_keras.export_saved_model(\n          tempfile.mkdtemp(dir=self._base_dir), serving_input_receiver_fn())\n\n  def test_export_subclassed_model_retains_model_state(self):\n    keras_model, (x_train, y_train), (\n        _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(\n            model_type='subclass', is_evaluate=True)\n    keras_model.compile(\n        optimizer=tf.compat.v1.train.RMSPropOptimizer(1e-3),\n        
loss='categorical_crossentropy',\n        metrics=['accuracy'])\n    keras_model.fit(x_train, y_train, epochs=1)\n    iterations = tf_keras.backend.get_value(keras_model.optimizer.iterations)\n    optimizer = keras_model.optimizer\n    est_keras = keras_lib.model_to_estimator(\n        keras_model=keras_model, config=self._config, checkpoint_format='saver')\n    est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)\n\n    # Subclassed models resets the model object. Assert that attributes are\n    # properly restored.\n    iterations_after = tf_keras.backend.get_value(\n        keras_model.optimizer.iterations)\n    self.assertEqual(optimizer, keras_model.optimizer)\n    self.assertEqual(iterations, iterations_after)\n    # TODO(b/132839451): model.fit results in an error after model_to_estimator.\n    # keras_model.fit(x_train, y_train, epochs=1)\n\n  def test_warm_start_from_keras_ckpt(self):\n    keras_model, (x_train, y_train), (\n        _, _), train_input_fn, eval_input_fn = get_resource_for_simple_model(\n            model_type='functional', is_evaluate=True)\n    keras_model.compile(\n        optimizer=tf.compat.v1.train.RMSPropOptimizer(1e-3),\n        loss='categorical_crossentropy',\n        metrics=['accuracy'])\n    keras_model.fit(x_train, y_train, epochs=1)\n\n    warm_start_path = os.path.join(self._config.model_dir, 'keras',\n                                   'warm_start.ckpt')\n    keras_model.save_weights(warm_start_path)\n\n    est_keras = keras_lib.model_to_estimator(\n        keras_model=keras_model, config=self._config, checkpoint_format='saver')\n\n    self.assertEqual(warm_start_path,\n                     est_keras._warm_start_settings.ckpt_to_initialize_from)\n    before_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)\n    est_keras.train(input_fn=train_input_fn, steps=_TRAIN_SIZE / 16)\n    after_eval_results = est_keras.evaluate(input_fn=eval_input_fn, steps=1)\n    
self.assertLess(after_eval_results['loss'], before_eval_results['loss'])\n\n  def test_sample_weights(self):\n    # Create simple pass-through model\n    input_layer = tf_keras.layers.Input(shape=1, name='input_layer')\n    keras_model = tf_keras.models.Model(inputs=input_layer, outputs=input_layer)\n\n    keras_model.compile(loss='mean_absolute_error', optimizer='adam')\n\n    features = [[0.], [0], [1], [1]]\n    sample_weights = [0, .4, 1, 1]\n    targets = [[0], [1], [0], [1]]\n\n    expected_loss = keras_model.test_on_batch(\n        tf.constant(features), tf.constant(targets),\n        tf.constant(sample_weights))\n\n    def input_fn():\n      dataset = tf.compat.v1.data.Dataset.from_tensors(({\n          'features': features,\n          'sample_weights': sample_weights\n      }, targets))\n      return dataset\n\n    est_keras = keras_lib.model_to_estimator(\n        keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir))\n    eval_results = est_keras.evaluate(input_fn, steps=1)\n    self.assertAllClose(expected_loss, eval_results['loss'])\n\n    # Test multiple with outputs and sample weights.\n    keras_model = tf_keras.models.Model(\n        inputs=input_layer, outputs=[input_layer, input_layer])\n    keras_model.compile(loss='mean_absolute_error', optimizer='adam')\n    expected_loss = keras_model.test_on_batch(\n        tf.constant(features),\n        [tf.constant(targets), tf.constant(targets)],\n        [tf.constant(sample_weights),\n         tf.constant(sample_weights)])[0]\n\n    def input_fn_multiple_targets():\n      dataset = tf.compat.v1.data.Dataset.from_tensors(\n          (features, sample_weights, targets))\n      dataset = dataset.map(lambda x, y, z: ({\n          'features': x,\n          'sample_weights': (y, y)\n      }, (z, z)))\n      return dataset\n\n    est_keras = keras_lib.model_to_estimator(\n        keras_model=keras_model, model_dir=tempfile.mkdtemp(dir=self._base_dir))\n    eval_results = 
est_keras.evaluate(input_fn_multiple_targets, steps=1)\n    self.assertAllClose(expected_loss, eval_results['loss'])\n\n  @parameterized.parameters([tf_keras_v2.layers.LSTM, tf_keras_v2.layers.GRU])\n  def test_model_to_estimator_with_rnn(self, layer):\n    # See https://github.com/tensorflow/tensorflow/issues/27750 for details.\n    timestep = 10\n    rnn_cell_size = 8\n\n    layers = [\n        tf_keras.layers.Reshape([timestep, 1], input_shape=[\n            timestep,\n        ]),\n        layer(rnn_cell_size, return_sequences=True),\n        layer(rnn_cell_size),\n        tf_keras.layers.Dense(1)\n    ]\n\n    model = tf_keras.models.Sequential(layers)\n    model.compile(loss='mse', optimizer='sgd')\n    keras_lib.model_to_estimator(\n        keras_model=model,\n        checkpoint_format='checkpoint',\n        model_dir=tempfile.mkdtemp(dir=self._base_dir))\n\n\ndef get_test_data(train_samples,\n                  test_samples,\n                  input_shape,\n                  num_classes,\n                  random_seed=None):\n  if random_seed is not None:\n    np.random.seed(random_seed)\n  num_sample = train_samples + test_samples\n  templates = 2 * num_classes * np.random.random((num_classes,) + input_shape)\n  y = np.random.randint(0, num_classes, size=(num_sample,))\n  x = np.zeros((num_sample,) + input_shape, dtype=np.float32)\n  for i in range(num_sample):\n    x[i] = templates[y[i]] + np.random.normal(loc=0, scale=1., size=input_shape)\n  return ((x[:train_samples], y[:train_samples]),\n          (x[train_samples:], y[train_samples:]))\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/mode_keys.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Exporting ModeKeys to tf.estimator namespace.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom tensorflow.python.saved_model.model_utils.mode_keys import EstimatorModeKeys as ModeKeys\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\n\nestimator_export('estimator.ModeKeys')(ModeKeys)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/model_fn.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Classes and methods related to model_fn.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.saved_model import model_utils as export_utils\nfrom tensorflow.python.tpu import tensor_tracer\nfrom tensorflow.python.util import function_utils\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\nLOSS_METRIC_KEY = 'loss'\nAVERAGE_LOSS_METRIC_KEY = 'average_loss'\n\n\n@estimator_export('estimator.EstimatorSpec')\nclass EstimatorSpec(\n    collections.namedtuple('EstimatorSpec', [\n        'mode', 'predictions', 'loss', 'train_op', 'eval_metric_ops',\n        'export_outputs', 'training_chief_hooks', 'training_hooks', 'scaffold',\n        'evaluation_hooks', 'prediction_hooks'\n    ])):\n  \"\"\"Ops and objects returned from a `model_fn` and passed to an `Estimator`.\n\n  `EstimatorSpec` fully defines the model to be run by an `Estimator`.\n  \"\"\"\n\n  def __new__(cls,\n              mode,\n              predictions=None,\n              
loss=None,\n              train_op=None,\n              eval_metric_ops=None,\n              export_outputs=None,\n              training_chief_hooks=None,\n              training_hooks=None,\n              scaffold=None,\n              evaluation_hooks=None,\n              prediction_hooks=None):\n    \"\"\"Creates a validated `EstimatorSpec` instance.\n\n    Depending on the value of `mode`, different arguments are required. Namely\n\n    * For `mode == ModeKeys.TRAIN`: required fields are `loss` and `train_op`.\n    * For `mode == ModeKeys.EVAL`: required field is `loss`.\n    * For `mode == ModeKeys.PREDICT`: required fields are `predictions`.\n\n    model_fn can populate all arguments independent of mode. In this case, some\n    arguments will be ignored by an `Estimator`. E.g. `train_op` will be\n    ignored in eval and infer modes. Example:\n\n    ```python\n    def my_model_fn(features, labels, mode):\n      predictions = ...\n      loss = ...\n      train_op = ...\n      return tf.estimator.EstimatorSpec(\n          mode=mode,\n          predictions=predictions,\n          loss=loss,\n          train_op=train_op)\n    ```\n\n    Alternatively, model_fn can just populate the arguments appropriate to the\n    given mode. Example:\n\n    ```python\n    def my_model_fn(features, labels, mode):\n      if (mode == tf.estimator.ModeKeys.TRAIN or\n          mode == tf.estimator.ModeKeys.EVAL):\n        loss = ...\n      else:\n        loss = None\n      if mode == tf.estimator.ModeKeys.TRAIN:\n        train_op = ...\n      else:\n        train_op = None\n      if mode == tf.estimator.ModeKeys.PREDICT:\n        predictions = ...\n      else:\n        predictions = None\n\n      return tf.estimator.EstimatorSpec(\n          mode=mode,\n          predictions=predictions,\n          loss=loss,\n          train_op=train_op)\n    ```\n\n    Args:\n      mode: A `ModeKeys`. 
Specifies if this is training, evaluation or\n        prediction.\n      predictions: Predictions `Tensor` or dict of `Tensor`.\n      loss: Training loss `Tensor`. Must be either scalar, or with shape `[1]`.\n      train_op: Op for the training step.\n      eval_metric_ops: Dict of metric results keyed by name.\n        The values of the dict can be one of the following: (1) instance of\n          `Metric` class. (2) Results of calling a metric function, namely a\n          `(metric_tensor, update_op)` tuple. `metric_tensor` should be\n          evaluated without any impact on state (typically is a pure computation\n          results based on variables.). For example, it should not trigger the\n          `update_op` or requires any input fetching.\n      export_outputs: Describes the output signatures to be exported to\n        `SavedModel` and used during serving.\n        A dict `{name: output}` where:\n        * name: An arbitrary name for this output.\n        * output: an `ExportOutput` object such as `ClassificationOutput`,\n          `RegressionOutput`, or `PredictOutput`. Single-headed models only need\n          to specify one entry in this dictionary. 
Multi-headed models should\n          specify one entry for each head, one of which must be named using\n          `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`.\n          If no entry is provided, a default `PredictOutput` mapping to\n          `predictions` will be created.\n      training_chief_hooks: Iterable of `tf.train.SessionRunHook` objects to run\n        on the chief worker during training.\n      training_hooks: Iterable of `tf.train.SessionRunHook` objects to run on\n        all workers during training.\n      scaffold: A `tf.train.Scaffold` object that can be used to set\n        initialization, saver, and more to be used in training.\n      evaluation_hooks: Iterable of `tf.train.SessionRunHook` objects to run\n        during evaluation.\n      prediction_hooks: Iterable of `tf.train.SessionRunHook` objects to run\n        during predictions.\n\n    Returns:\n      A validated `EstimatorSpec` object.\n\n    Raises:\n      ValueError: If validation fails.\n      TypeError: If any of the arguments is not the expected type.\n    \"\"\"\n    train_op = _validate_estimator_spec_train_op(train_op, mode)\n    loss = _validate_estimator_spec_loss(loss, mode)\n    predictions = _validate_estimator_spec_predictions(predictions, mode)\n    export_outputs = _validate_estimator_spec_export_outputs(\n        export_outputs, predictions, mode)\n    training_hooks = _validate_estimator_spec_hooks(training_hooks)\n    evaluation_hooks = _validate_estimator_spec_hooks(evaluation_hooks)\n    prediction_hooks = _validate_estimator_spec_hooks(prediction_hooks)\n    training_chief_hooks = _validate_estimator_spec_hooks(training_chief_hooks)\n    eval_metric_ops = _validate_eval_metric_ops(eval_metric_ops)\n    scaffold = _validate_scaffold(scaffold)\n\n    # By default, Tensor Tracer is not enabled and the block below is an no-op.\n    if tensor_tracer.TensorTracer.is_enabled() and train_op is not None:\n      # If Tensor Tracer is enabled via 
environment flags, loss and train_op\n      # will be used to determine the execution path that will be traced. A\n      # `tf.identity` of loss that enforces the execution of tracing ops will be\n      # returned.\n      tt = tensor_tracer.TensorTracer()\n      loss = tt.trace_cpu(tf.compat.v1.get_default_graph(), loss, train_op)\n\n    return super(EstimatorSpec, cls).__new__(\n        cls,\n        mode=mode,\n        predictions=predictions,\n        loss=loss,\n        train_op=train_op,\n        eval_metric_ops=eval_metric_ops,\n        export_outputs=export_outputs,\n        training_chief_hooks=training_chief_hooks,\n        training_hooks=training_hooks,\n        scaffold=scaffold,\n        evaluation_hooks=evaluation_hooks,\n        prediction_hooks=prediction_hooks)\n\n  def _replace(self, **kwds):\n    \"\"\"Return a new EstimatorSpec replacing specified fields with new values.\"\"\"\n    if 'mode' in kwds:\n      if self.mode != kwds['mode']:\n        raise ValueError('mode of EstimatorSpec cannot be changed.')\n    new_fields = map(kwds.pop, self._fields, list(self))\n    return EstimatorSpec(*new_fields)\n\n\nclass _TPUEstimatorSpec(\n    collections.namedtuple('TPUEstimatorSpec', [\n        'mode', 'predictions', 'loss', 'train_op', 'eval_metrics',\n        'export_outputs', 'scaffold_fn', 'host_call', 'training_hooks',\n        'evaluation_hooks', 'prediction_hooks'\n    ])):\n  \"\"\"Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.\n\n  This is a simplified implementation of `tf.contrib.tpu.EstimatorSpec`. 
See\n  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py for more detailed\n  documentation.\n  \"\"\"\n\n  def __new__(cls,\n              mode,\n              predictions=None,\n              loss=None,\n              train_op=None,\n              eval_metrics=None,\n              export_outputs=None,\n              scaffold_fn=None,\n              host_call=None,\n              training_hooks=None,\n              evaluation_hooks=None,\n              prediction_hooks=None):\n    \"\"\"Creates a `_TPUEstimatorSpec` instance.\"\"\"\n    train_op = _validate_estimator_spec_train_op(train_op, mode)\n    loss = _validate_estimator_spec_loss(loss, mode)\n    predictions = _validate_estimator_spec_predictions(predictions, mode)\n    export_outputs = _validate_estimator_spec_export_outputs(\n        export_outputs, predictions, mode)\n    training_hooks = _validate_estimator_spec_hooks(training_hooks)\n    evaluation_hooks = _validate_estimator_spec_hooks(evaluation_hooks)\n    prediction_hooks = _validate_estimator_spec_hooks(prediction_hooks)\n    return super(_TPUEstimatorSpec, cls).__new__(\n        cls,\n        mode=mode,\n        predictions=predictions,\n        loss=loss,\n        train_op=train_op,\n        eval_metrics=eval_metrics,\n        export_outputs=export_outputs,\n        scaffold_fn=scaffold_fn,\n        host_call=host_call,\n        training_hooks=training_hooks,\n        evaluation_hooks=evaluation_hooks,\n        prediction_hooks=prediction_hooks)\n\n  def as_estimator_spec(self):\n    \"\"\"Creates an equivalent `EstimatorSpec` used by CPU train/eval.\"\"\"\n    if not self.eval_metrics:\n      eval_metric_ops = None\n    else:\n      metric_fn, tensors = self.eval_metrics\n      eval_metric_ops = metric_fn(**tensors)\n    return EstimatorSpec(\n        mode=self.mode,\n        predictions=self.predictions,\n        loss=self.loss,\n        train_op=self.train_op,\n        eval_metric_ops=eval_metric_ops,\n        
export_outputs=self.export_outputs,\n        training_hooks=self.training_hooks,\n        evaluation_hooks=self.evaluation_hooks,\n        prediction_hooks=self.prediction_hooks)\n\n\n# Used to generate possible error causes if the user provides a `Tensor` to an\n# EstimatorSpec that is not in the default graph.\n_default_graph_error_message_template = (\n    '{0} with \"{1}\" must be from the default graph. '\n    'Possible causes of this error include: \\n\\n'\n    '1) {0} was created outside the context of the default graph.'\n    '\\n\\n'\n    '2) The object passed through to EstimatorSpec was not created '\n    'in the most recent call to \"model_fn\".')\n\n\ndef _validate_estimator_spec_train_op(train_op, mode):\n  \"\"\"Validate train_op inputs for EstimatorSpec or TPUEstimatorSpec.\n\n  Args:\n    train_op: Op for the training step.\n    mode: A `ModeKeys`. Used to determine whether the train_op is acceptable for\n      use in the current mode; for example, if we are not training, this can be\n      None.\n\n  Returns:\n    train_op: Op for the training step.\n\n  Raises:\n    ValueError: If no train_op is passed during training.\n    TypeError:  If:\n                - train_op is neither a `Tensor` nor an Op.\n                - train_op is not part of the default graph.\n  \"\"\"\n  if train_op is None:\n    if mode == ModeKeys.TRAIN:\n      raise ValueError('Missing train_op.')\n  else:\n    default_graph = tf.compat.v1.get_default_graph()\n    _check_is_tensor_or_operation(train_op, 'train_op')\n    if isinstance(train_op, tf.Variable):\n      train_op = train_op.op\n    if not (tf.executing_eagerly() or train_op.graph is default_graph):\n      raise ValueError(\n          _default_graph_error_message_template.format('train_op',\n                                                       train_op.name))\n  return train_op\n\n\ndef _validate_estimator_spec_loss(loss, mode):\n  \"\"\"Validate loss inputs for EstimatorSpec or TPUEstimatorSpec.\n\n  Args:\n    
loss: Training loss `Tensor`. Must either be scalar, or with shape `[1]`.\n    mode: A `ModeKeys`. Used to determine whether the loss is acceptable for use\n      in the current mode; for example, None is acceptable if we are not\n      training or evaluating.\n\n  Returns:\n    loss: Training loss `Tensor`.\n\n  Raises:\n    ValueError: If the loss `Tensor` is not appropriately formatted.\n    TypeError:  If:\n                - a non-`Tensor`, non-None input is passed.\n                - the loss `Tensor` is not part of the default graph.\n  \"\"\"\n  if loss is None:\n    if mode in (ModeKeys.TRAIN, ModeKeys.EVAL):\n      raise ValueError('Missing loss.')\n  else:\n    default_graph = tf.compat.v1.get_default_graph()\n    # Loss must be a tensor.\n    loss = _check_is_tensor(loss, 'loss')\n    loss_shape = loss.get_shape()\n    if loss_shape.num_elements() not in (None, 1):\n      raise ValueError('Loss must be scalar, given: {}'.format(loss))\n    if not loss_shape.is_compatible_with(tf.TensorShape([])):\n      loss = tf.reshape(loss, [])\n    if not (tf.executing_eagerly() or loss.graph is default_graph):\n      raise ValueError(\n          _default_graph_error_message_template.format('loss', loss.name))\n  return loss\n\n\ndef _validate_estimator_spec_predictions(predictions, mode):\n  \"\"\"Validate predictions inputs for EstimatorSpec or TPUEstimatorSpec.\n\n  Args:\n    predictions: Predictions `Tensor` or dict of `Tensor`.\n    mode: A `ModeKeys`. 
Used to determine whether the predictions are acceptable\n      for use in the current mode; None is acceptable if we are not making\n      predictions.\n\n  Returns:\n    predictions: Predictions `Tensor` or dict of `Tensor`.\n\n  Raises:\n    ValueError: If:\n      - predictions is None and we are in predict mode.\n      - predictions `Tensor` is not in default_graph or else it is a dict of\n        `Tensor` where at least one is not in default_graph.\n    TypeError:  If predictions is not a `Tensor` or dict of `Tensor`.\n  \"\"\"\n  if predictions is None:\n    if mode == ModeKeys.PREDICT:\n      raise ValueError('Missing predictions.')\n    predictions = {}\n  else:\n    default_graph = tf.compat.v1.get_default_graph()\n    if isinstance(predictions, dict):\n      predictions = {\n          k: _check_is_tensor(v, 'predictions[{}]'.format(k))\n          for k, v in six.iteritems(predictions)\n      }\n      if not tf.executing_eagerly():\n        for key, value in six.iteritems(predictions):\n          if value.graph is not default_graph:\n            raise ValueError(\n                _default_graph_error_message_template.format(\n                    'prediction values', '{0}: {1}'.format(key, value.name)))\n    else:\n      # Predictions should be a tensor.\n      predictions = _check_is_tensor(predictions, 'predictions')\n      if not (tf.executing_eagerly() or predictions.graph is default_graph):\n        raise ValueError(\n            _default_graph_error_message_template.format(\n                'prediction values', predictions.name))\n  return predictions\n\n\ndef _validate_estimator_spec_export_outputs(export_outputs, predictions, mode):\n  \"\"\"Validate export_outputs inputs for EstimatorSpec or TPUEstimatorSpec.\n\n  Args:\n    export_outputs: Describes the output signatures to be exported to\n      `SavedModel` and used during serving.\n      A dict `{name: output}` where:\n      * name: An arbitrary name for this output.\n      * output: an 
`ExportOutput` object such as `ClassificationOutput`\n        `RegressionOutput`, or `PredictOutput`. Single-headed models should only\n        need to specify one entry in this dictionary. Multi-headed models should\n        specify one entry for each head, one of which must be named using\n        `tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY`.\n        If no entry is provided, a default `PredictOutput` mapping to\n        predictions will be created.\n    predictions: Predictions `Tensor` or dict of `Tensor`. Used in generation of\n      default outputs.\n    mode: A `ModeKeys`. Used to determine whether to validate at all; if the\n      EstimatorSpec is not for making predictions we can skip validation.\n\n  Returns:\n    ValueError: If validation fails.\n    TypeError: If the export_outputs is not a dict or the values of the dict are\n               not instances of type `ExportOutput`.\n  \"\"\"\n  if mode == ModeKeys.PREDICT:\n    export_outputs = export_utils.get_export_outputs(export_outputs,\n                                                     predictions)\n  return export_outputs\n\n\ndef _validate_estimator_spec_hooks(hooks):\n  \"\"\"Validate SessionRunHooks for use in EstimatorSpec or TPUEstimatorSpec.\n\n  Args:\n    hooks: Iterable of `tf.train.SessionRunHook` objects to run on all workers.\n\n  Returns:\n    hooks: Iterable of `tf.train.SessionRunHook` objects.\n\n  Raises:\n    ValueError: If validation fails.\n    TypeError:  If any element of the iterable is not a SessionRunHook.\n  \"\"\"\n  hooks = tuple(hooks or [])\n\n  for hook in hooks:\n    if not isinstance(hook, tf.compat.v1.train.SessionRunHook):\n      raise TypeError(\n          'All hooks must be SessionRunHook instances, given: {}'.format(hook))\n  return hooks\n\n\ndef _validate_eval_metric_ops(eval_metric_ops):\n  \"\"\"Validate eval_metric_ops for use in EstimatorSpec.\n\n  Args:\n    eval_metric_ops: Dict of metric results keyed by name.\n      The 
values of the dict can be one of the following: (1) instance of\n        `Metric` class. (2) Results of calling a metric_function, namely a\n        `(metric_tensor, update_op)` tuple. `metric_tensor` should be evaluated\n        without any impact on state (typically it is a pure computation based on\n        variables.). For example, it should not trigger the `update_op` or\n        require any input fetching.\n\n  Returns:\n    eval_metric_ops: Dict of metric results keyed by name.\n\n  Raises:\n    ValueError:  If:\n     - one of the eval_metric_ops `Metric` objects has no updates.\n     - there is at least one `Metric` update or result, `Tensor`, or Op that is\n       not in the default graph.\n    TypeError:   If:\n     - eval_metric_ops is not a dict or None.\n     - an element of eval_metric_ops is not a `Metric` or a 2-tuple.\n     - an element of eval_metric_ops has a sub-element that is not a `Tensor` or\n       an Op.\n  \"\"\"\n  if eval_metric_ops is None:\n    eval_metric_ops = {}\n  else:\n    if not isinstance(eval_metric_ops, dict):\n      raise TypeError(\n          'eval_metric_ops must be a dict, given: {}'.format(eval_metric_ops))\n    for key, value in six.iteritems(eval_metric_ops):\n      # TODO(psv): When we deprecate the old metrics, throw an error here if\n      # the value is not an instance of `Metric` class.\n      if isinstance(value, tf_keras.metrics.Metric):\n        if not value.updates:  # Check if metric updates are available.\n          raise ValueError(\n              'Please call update_state(...) 
on the \"{metric_name}\" metric'\n              .format(metric_name=value.name))\n      else:\n        if not isinstance(value, tuple) or len(value) != 2:\n          raise TypeError(\n              'Values of eval_metric_ops must be (metric_value, update_op) '\n              'tuples, given: {} for key: {}'.format(value, key))\n  # Verify all tensors and ops are from default graph.\n  default_graph = tf.compat.v1.get_default_graph()\n  for key, value in list(six.iteritems(eval_metric_ops)):\n    if isinstance(value, tf_keras.metrics.Metric):\n      values_to_check = value.updates[:]\n      values_to_check.append(value.result())\n    else:\n      values_to_check = tf.nest.flatten(value)\n    for val in values_to_check:\n      if not (tf.executing_eagerly() or val.graph is default_graph):\n        raise ValueError(\n            _default_graph_error_message_template.format(\n                'eval_metric_ops', '{0}: {1}'.format(key, val.name)))\n  # Metric variables are by default not added to any collections. The variables\n  # are appended to the LOCAL_VARIABLES collection for initialization, and\n  # METRIC_VARIABLES for TFMA compatibility. 
Note that although collections are\n  # officially deprecated in TensorFlow 2, Estimators will continue using\n  # collections as long as it supports V1 graph mode.\n  vars_to_add = set()\n  for key, value in six.iteritems(eval_metric_ops):\n    if isinstance(value, tf_keras.metrics.Metric):\n      vars_to_add.update(value.variables)\n      # Convert Metric instances to (value_tensor, update_op) tuple.\n      eval_metric_ops[key] = (value.result(), value.updates[0])\n  _update_variable_collection(tf.compat.v1.GraphKeys.LOCAL_VARIABLES,\n                              vars_to_add)\n  _update_variable_collection(tf.compat.v1.GraphKeys.METRIC_VARIABLES,\n                              vars_to_add)\n\n  return eval_metric_ops\n\n\ndef _update_variable_collection(collection_name, vars_to_add):\n  \"\"\"Add variables to collection.\"\"\"\n  collection = set(tf.compat.v1.get_collection(collection_name))\n  # Skip variables that are in the collection already.\n  vars_to_add = vars_to_add.difference(collection)\n  for v in vars_to_add:\n    tf.compat.v1.add_to_collection(collection_name, v)\n\n\ndef _validate_scaffold(scaffold):\n  \"\"\"Validate scaffold input for EstimatorSpec.\n\n  Args:\n    scaffold: A `tf.train.Scaffold` object that can be used to set\n      initialization, saver, and more to be used in training.\n\n  Returns:\n    scaffold: A `tf.train.Scaffold` object. If no scaffold is provided, then a\n      default is generated.\n\n  Raises:\n    TypeError: If the scaffold is not of type `monitored_session.Scaffold`\n      or None.\n  \"\"\"\n  scaffold = scaffold or tf.compat.v1.train.Scaffold()\n  if not isinstance(scaffold, tf.compat.v1.train.Scaffold):\n    raise TypeError(\n        'scaffold must be tf.train.Scaffold. 
Given: {}'.format(scaffold))\n  return scaffold\n\n\ndef _check_is_tensor_or_operation(x, name):\n  # TODO(b/154650521): Use tf.Tensor instead of core.Tensor.\n  if not isinstance(x, (tf.Operation, tf.compat.v2.__internal__.types.Tensor)):\n    raise TypeError('{} must be Operation or Tensor, given: {}'.format(name, x))\n\n\ndef _check_is_tensor(x, tensor_name):\n  \"\"\"Returns `x` if it is a `Tensor`, raises TypeError otherwise.\"\"\"\n  if not isinstance(x, tf.compat.v2.__internal__.types.Tensor):\n    raise TypeError('{} must be Tensor, given: {}'.format(tensor_name, x))\n  return x\n\n\n@estimator_export('estimator.experimental.call_logit_fn')\ndef call_logit_fn(logit_fn, features, mode, params, config):\n  \"\"\"Calls logit_fn (experimental).\n\n  THIS FUNCTION IS EXPERIMENTAL. Keras layers/models are the recommended APIs\n  for logit and model composition.\n\n  A utility function that calls the provided logit_fn with the relevant subset\n  of provided arguments. Similar to tf.estimator._call_model_fn().\n\n  Args:\n    logit_fn: A logit_fn as defined above.\n    features: The features dict.\n    mode: TRAIN / EVAL / PREDICT ModeKeys.\n    params: The hyperparameter dict.\n    config: The configuration object.\n\n  Returns:\n    A logit Tensor, the output of logit_fn.\n\n  Raises:\n    ValueError: if logit_fn does not return a Tensor or a dictionary mapping\n      strings to Tensors.\n  \"\"\"\n  logit_fn_args = function_utils.fn_args(logit_fn)\n  kwargs = {}\n  if 'mode' in logit_fn_args:\n    kwargs['mode'] = mode\n  if 'params' in logit_fn_args:\n    kwargs['params'] = params\n  if 'config' in logit_fn_args:\n    kwargs['config'] = config\n  logit_fn_results = logit_fn(features=features, **kwargs)\n\n  result_is_valid_dictionary = (\n      isinstance(logit_fn_results, dict) and\n      all([(isinstance(k, six.string_types) and isinstance(v, tf.Tensor))\n           for k, v in six.iteritems(logit_fn_results)]))\n  result_is_tensor = 
isinstance(logit_fn_results, tf.Tensor)\n\n  if not (result_is_valid_dictionary or result_is_tensor):\n    raise ValueError('logit_fn should return a Tensor or a dictionary mapping '\n                     'strings to Tensors.  logit_fn returned: %s' %\n                     logit_fn_results)\n\n  return logit_fn_results\n\n\n_VALID_MODEL_FN_ARGS = set(\n    ['features', 'labels', 'mode', 'params', 'self', 'config'])\n\n\ndef verify_model_fn_args(model_fn, params):\n  \"\"\"Verifies `model_fn` arguments.\"\"\"\n  args = set(function_utils.fn_args(model_fn))\n  if 'features' not in args:\n    raise ValueError('model_fn (%s) must include features argument.' % model_fn)\n  if params is not None and 'params' not in args:\n    raise ValueError('model_fn (%s) does not include params argument, '\n                     'but params (%s) is passed to Estimator.' %\n                     (model_fn, params))\n  if params is None and 'params' in args:\n    tf.compat.v1.logging.warn(\n        'Estimator\\'s model_fn (%s) includes params '\n        'argument, but params are not passed to Estimator.', model_fn)\n  non_valid_args = list(args - _VALID_MODEL_FN_ARGS)\n  if non_valid_args:\n    raise ValueError('model_fn (%s) has following not expected args: %s' %\n                     (model_fn, non_valid_args))\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/model_fn_test.py",
    "content": "# Copyright 2016 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for model_fn.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator import model_fn\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.export import export_output\nfrom tensorflow_estimator.python.estimator.mode_keys import ModeKeys\n\n\nclass _FakeHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Fake implementation of `SessionRunHook`.\"\"\"\n\n\nclass _InvalidHook(object):\n  \"\"\"Invalid hook (not a subclass of `SessionRunHook`).\"\"\"\n\n\nclass _InvalidScaffold(object):\n  \"\"\"Invalid scaffold (not a subclass of `Scaffold`).\"\"\"\n\n\nclass EstimatorSpecTrainTest(tf.test.TestCase):\n  \"\"\"Tests EstimatorSpec in train mode.\"\"\"\n\n  def testRequiredArgumentsSet(self):\n    \"\"\"Tests that no errors are raised when all required arguments are set.\"\"\"\n    with tf.Graph().as_default(), self.cached_session():\n      model_fn.EstimatorSpec(\n          mode=ModeKeys.TRAIN, loss=tf.constant(1.), train_op=tf.no_op())\n\n  def testAllArgumentsSet(self):\n    \"\"\"Tests that no errors are raised when all arguments are set.\"\"\"\n    with 
tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant(1.)\n      predictions = {'loss': loss}\n      classes = tf.constant('hello')\n      metric_obj = tf_keras.metrics.Mean()\n      metric_obj.update_state(loss)\n      model_fn.EstimatorSpec(\n          mode=ModeKeys.TRAIN,\n          predictions=predictions,\n          loss=loss,\n          train_op=tf.no_op(),\n          eval_metric_ops={\n              'loss': (tf.no_op(), loss),\n              'mean': metric_obj,\n          },\n          export_outputs={\n              'head_name': export_output.ClassificationOutput(classes=classes)\n          },\n          training_chief_hooks=[_FakeHook()],\n          training_hooks=[_FakeHook()],\n          scaffold=tf.compat.v1.train.Scaffold(),\n          evaluation_hooks=[_FakeHook()],\n          prediction_hooks=[_FakeHook()])\n\n  def testLossNumber(self):\n    \"\"\"Tests that error is raised when loss is a number (not Tensor).\"\"\"\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.TRAIN, loss=1., train_op=tf.no_op())\n\n  def testLoss1DTensor(self):\n    \"\"\"Tests that no errors are raised when loss is 1D tensor.\"\"\"\n    with tf.Graph().as_default(), self.cached_session():\n      model_fn.EstimatorSpec(\n          mode=ModeKeys.TRAIN, loss=tf.constant([1.]), train_op=tf.no_op())\n\n  def testLossMissing(self):\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(ValueError, 'Missing loss'):\n        model_fn.EstimatorSpec(mode=ModeKeys.TRAIN, train_op=tf.no_op())\n\n  def testLossNotScalar(self):\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.TRAIN,\n            loss=tf.constant([1., 2.]),\n            
train_op=tf.no_op())\n\n  def testLossSparseTensor(self):\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.sparse.SparseTensor(indices=[[0]], values=[0.], dense_shape=[1])\n      with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.TRAIN, loss=loss, train_op=tf.no_op())\n\n  def testLossFromDifferentGraph(self):\n    with tf.Graph().as_default():\n      loss = tf.constant(1.)\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(ValueError,\n                                   'must be from the default graph'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.TRAIN, loss=loss, train_op=tf.no_op())\n\n  def testTrainOpMissing(self):\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(ValueError, 'Missing train_op'):\n        model_fn.EstimatorSpec(mode=ModeKeys.TRAIN, loss=tf.constant(1.))\n\n  def testTrainOpNotOperationAndTensor(self):\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(TypeError,\n                                   'train_op must be Operation or Tensor'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.TRAIN,\n            loss=tf.constant(1.),\n            train_op='Not an Operation or Tensor')\n\n  def testTrainOpFromDifferentGraph(self):\n    with tf.Graph().as_default():\n      train_op = tf.no_op()\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(ValueError,\n                                   'must be from the default graph'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.TRAIN, loss=tf.constant(1.), train_op=train_op)\n\n  def testTrainingChiefHookInvalid(self):\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(\n          TypeError, 'All hooks must be SessionRunHook instances'):\n        
model_fn.EstimatorSpec(\n            mode=ModeKeys.TRAIN,\n            loss=tf.constant(1.),\n            train_op=tf.no_op(),\n            training_chief_hooks=[_InvalidHook()])\n\n  def testTrainingHookInvalid(self):\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(\n          TypeError, 'All hooks must be SessionRunHook instances'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.TRAIN,\n            loss=tf.constant(1.),\n            train_op=tf.no_op(),\n            training_hooks=[_InvalidHook()])\n\n  def testScaffoldInvalid(self):\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(TypeError,\n                                   r'scaffold must be tf\\.train\\.Scaffold'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.TRAIN,\n            loss=tf.constant(1.),\n            train_op=tf.no_op(),\n            scaffold=_InvalidScaffold())\n\n  def testReturnDefaultScaffold(self):\n    with tf.Graph().as_default(), self.cached_session():\n      estimator_spec = model_fn.EstimatorSpec(\n          mode=ModeKeys.TRAIN, loss=tf.constant(1.), train_op=tf.no_op())\n      self.assertIsNotNone(estimator_spec.scaffold)\n\n\nclass EstimatorSpecEvalTest(tf.test.TestCase):\n  \"\"\"Tests EstimatorSpec in eval mode.\"\"\"\n\n  def testRequiredArgumentsSet(self):\n    \"\"\"Tests that no errors are raised when all required arguments are set.\"\"\"\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant(1.)\n      model_fn.EstimatorSpec(\n          mode=ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)\n\n  def testAllArgumentsSet(self):\n    \"\"\"Tests that no errors are raised when all arguments are set.\"\"\"\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant(1.)\n      predictions = {'loss': loss}\n      classes = tf.constant('hello')\n      metric_obj = tf_keras.metrics.Mean()\n      
metric_obj.update_state(loss)\n      model_fn.EstimatorSpec(\n          mode=ModeKeys.EVAL,\n          predictions=predictions,\n          loss=loss,\n          train_op=tf.no_op(),\n          eval_metric_ops={\n              'loss': (tf.no_op(), loss),\n              'mean': metric_obj,\n          },\n          export_outputs={\n              'head_name': export_output.ClassificationOutput(classes=classes)\n          },\n          training_chief_hooks=[_FakeHook()],\n          training_hooks=[_FakeHook()],\n          scaffold=tf.compat.v1.train.Scaffold(),\n          evaluation_hooks=[_FakeHook()])\n\n  def testEvaluationHookInvalid(self):\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(\n          TypeError, 'All hooks must be SessionRunHook instances'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.EVAL,\n            loss=tf.constant(1.),\n            evaluation_hooks=[_InvalidHook()])\n\n  def testTupleMetric(self):\n    \"\"\"Tests that no errors are raised when a metric is tuple-valued.\"\"\"\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant(1.)\n      model_fn.EstimatorSpec(\n          mode=ModeKeys.EVAL,\n          loss=loss,\n          eval_metric_ops={\n              'some_metric': ((loss, loss, (tf.constant(2), loss)), tf.no_op())\n          })\n\n  def testLoss1DTensor(self):\n    \"\"\"Tests that no errors are raised when loss is 1D tensor.\"\"\"\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant([1.])\n      model_fn.EstimatorSpec(\n          mode=ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)\n\n  def testLossNumber(self):\n    \"\"\"Tests that error is raised when loss is a number (not Tensor).\"\"\"\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.EVAL, predictions={'loss': 
tf.constant(1.)}, loss=1.)\n\n  def testLossMissing(self):\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(ValueError, 'Missing loss'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.EVAL, predictions={'loss': tf.constant(1.)})\n\n  def testLossNotScalar(self):\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant([1., 2.])\n      with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)\n\n  def testLossSparseTensor(self):\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.sparse.SparseTensor(indices=[[0]], values=[0.], dense_shape=[1])\n      with self.assertRaisesRegexp(TypeError, 'loss must be Tensor'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.EVAL,\n            predictions={'prediction': tf.constant(1.)},\n            loss=loss)\n\n  def testLossFromDifferentGraph(self):\n    with tf.Graph().as_default():\n      loss = tf.constant(1.)\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(ValueError,\n                                   'must be from the default graph'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.EVAL,\n            predictions={'prediction': tf.constant(1.)},\n            loss=loss)\n\n  def testReplaceRaisesConstructorChecks(self):\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant(1.)\n      spec = model_fn.EstimatorSpec(\n          mode=ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)\n      with self.assertRaisesRegexp(ValueError, 'Loss must be scalar'):\n        spec._replace(loss=tf.constant([1., 2.]))\n\n  def testReplaceDoesReplace(self):\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant(1.)\n      spec = model_fn.EstimatorSpec(\n          mode=ModeKeys.EVAL, 
predictions={'loss': loss}, loss=loss)\n      new_spec = spec._replace(predictions={'m': loss})\n      self.assertEqual(['m'], list(new_spec.predictions.keys()))\n\n  def testReplaceNotAllowModeChange(self):\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant(1.)\n      spec = model_fn.EstimatorSpec(\n          mode=ModeKeys.EVAL, predictions={'loss': loss}, loss=loss)\n      spec._replace(mode=ModeKeys.EVAL)\n      with self.assertRaisesRegexp(ValueError,\n                                   'mode of EstimatorSpec cannot be changed'):\n        spec._replace(mode=ModeKeys.TRAIN)\n\n  def testPredictionsMissingIsOkay(self):\n    with tf.Graph().as_default(), self.cached_session():\n      model_fn.EstimatorSpec(mode=ModeKeys.EVAL, loss=tf.constant(1.))\n\n  def testPredictionsTensor(self):\n    \"\"\"Tests that no error is raised when predictions is Tensor (not dict).\"\"\"\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant(1.)\n      model_fn.EstimatorSpec(mode=ModeKeys.EVAL, predictions=loss, loss=loss)\n\n  def testPredictionsNumber(self):\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(TypeError,\n                                   r'predictions\\[number\\] must be Tensor'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.EVAL,\n            predictions={'number': 1.},\n            loss=tf.constant(1.))\n\n  def testPredictionsSparseTensor(self):\n    with tf.Graph().as_default(), self.cached_session():\n      predictions = {\n          'sparse':\n              tf.sparse.SparseTensor(\n                  indices=[[0]], values=[0.], dense_shape=[1])\n      }\n      with self.assertRaisesRegexp(TypeError,\n                                   r'predictions\\[sparse\\] must be Tensor'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.EVAL, predictions=predictions, loss=tf.constant(1.))\n\n  def testPredictionsFromDifferentGraph(self):\n   
 with tf.Graph().as_default():\n      predictions = {'loss': tf.constant(1.)}\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(ValueError,\n                                   'must be from the default graph'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.EVAL, predictions=predictions, loss=tf.constant(1.))\n\n  def testEvalMetricOpsNoDict(self):\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant(1.)\n      with self.assertRaisesRegexp(TypeError, 'eval_metric_ops must be a dict'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.EVAL,\n            predictions={'loss': loss},\n            loss=loss,\n            eval_metric_ops=loss)\n\n  def testEvalMetricOpsNoTuple(self):\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant(1.)\n      with self.assertRaisesRegexp(\n          TypeError,\n          (r'Values of eval_metric_ops must be \\(metric_value, update_op\\) '\n           'tuples')):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.EVAL,\n            predictions={'loss': loss},\n            loss=loss,\n            eval_metric_ops={'loss': loss})\n\n  def testEvalMetricOpsFromDifferentGraphWithMetricTuple(self):\n    with tf.Graph().as_default():\n      eval_metric_ops = {'loss': (tf.no_op(), tf.constant(1.))}\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant(1.)\n      with self.assertRaisesRegexp(ValueError,\n                                   'must be from the default graph'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.EVAL,\n            predictions={'loss': loss},\n            loss=loss,\n            eval_metric_ops=eval_metric_ops)\n\n  def testEvalMetricOpsFromDifferentGraphWithMetricObject(self):\n    with tf.Graph().as_default():\n      metric_obj = tf_keras.metrics.Mean()\n      metric_obj.update_state(tf.constant(1.))\n      eval_metric_ops = {'metric': 
metric_obj}\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant(1.)\n      with self.assertRaisesRegexp(ValueError,\n                                   'must be from the default graph'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.EVAL,\n            predictions={'loss': loss},\n            loss=loss,\n            eval_metric_ops=eval_metric_ops)\n\n  def testEvalMetricOpsWithoutUpdates(self):\n    with tf.Graph().as_default():\n      eval_metric_ops = {'mean': tf_keras.metrics.Mean()}\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant(1.)\n      with self.assertRaisesRegexp(ValueError, 'Please call update_state(...)'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.EVAL,\n            predictions={'loss': loss},\n            loss=loss,\n            eval_metric_ops=eval_metric_ops)\n\n  def testMetricVariablesAddedToCollections(self):\n\n    def in_collection(collection_name, variables):\n      \"\"\"Returns whether all variables are in the collection.\"\"\"\n      return set(tf.compat.v1.get_collection(collection_name)).issuperset(\n          set(variables))\n\n    with tf.Graph().as_default():\n      metric_obj = tf_keras.metrics.Mean()\n      metric_obj.update_state(tf.constant(1.))\n      self.assertFalse(\n          in_collection(tf.compat.v1.GraphKeys.LOCAL_VARIABLES,\n                        metric_obj.variables))\n      self.assertFalse(\n          in_collection(tf.compat.v1.GraphKeys.METRIC_VARIABLES,\n                        metric_obj.variables))\n      model_fn.EstimatorSpec(\n          mode=ModeKeys.EVAL,\n          predictions=tf.constant(1.),\n          loss=tf.constant(1.),\n          eval_metric_ops={'metric': metric_obj})\n      self.assertTrue(\n          in_collection(tf.compat.v1.GraphKeys.LOCAL_VARIABLES,\n                        metric_obj.variables))\n      self.assertTrue(\n          in_collection(tf.compat.v1.GraphKeys.METRIC_VARIABLES,\n            
            metric_obj.variables))\n\n\nclass EstimatorSpecInferTest(tf.test.TestCase):\n  \"\"\"Tests EstimatorSpec in infer mode.\"\"\"\n\n  def testRequiredArgumentsSet(self):\n    \"\"\"Tests that no errors are raised when all required arguments are set.\"\"\"\n    with tf.Graph().as_default(), self.cached_session():\n      model_fn.EstimatorSpec(\n          mode=ModeKeys.PREDICT, predictions={'loss': tf.constant(1.)})\n\n  def testAllArgumentsSet(self):\n    \"\"\"Tests that no errors are raised when all arguments are set.\"\"\"\n    with tf.Graph().as_default(), self.cached_session():\n      loss = tf.constant(1.)\n      predictions = {'loss': loss}\n      classes = tf.constant('hello')\n      metric_obj = tf_keras.metrics.Mean()\n      metric_obj.update_state(loss)\n      model_fn.EstimatorSpec(\n          mode=ModeKeys.PREDICT,\n          predictions=predictions,\n          loss=loss,\n          train_op=tf.no_op(),\n          eval_metric_ops={\n              'loss': (tf.no_op(), loss),\n              'mean': metric_obj,\n          },\n          export_outputs={\n              'head_name': export_output.ClassificationOutput(classes=classes)\n          },\n          training_chief_hooks=[_FakeHook()],\n          training_hooks=[_FakeHook()],\n          scaffold=tf.compat.v1.train.Scaffold(),\n          evaluation_hooks=[_FakeHook()],\n          prediction_hooks=[_FakeHook()])\n\n  def testPredictionHookInvalid(self):\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(\n          TypeError, 'All hooks must be SessionRunHook instances'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.PREDICT,\n            predictions=tf.constant(1.),\n            prediction_hooks=[_InvalidHook()])\n\n  def testPredictionsMissing(self):\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(ValueError, 'Missing predictions'):\n        model_fn.EstimatorSpec(mode=ModeKeys.PREDICT)\n\n  
def testPredictionsTensor(self):\n    \"\"\"Tests that no error is raised when predictions is Tensor (not dict).\"\"\"\n    with tf.Graph().as_default(), self.cached_session():\n      model_fn.EstimatorSpec(mode=ModeKeys.PREDICT, predictions=tf.constant(1.))\n\n  def testPredictionsNumber(self):\n    with tf.Graph().as_default(), self.cached_session():\n      with self.assertRaisesRegexp(TypeError,\n                                   r'predictions\\[number\\] must be Tensor'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.PREDICT, predictions={'number': 1.})\n\n  def testPredictionsSparseTensor(self):\n    with tf.Graph().as_default(), self.cached_session():\n      predictions = {\n          'sparse':\n              tf.sparse.SparseTensor(\n                  indices=[[0]], values=[0.], dense_shape=[1])\n      }\n      with self.assertRaisesRegexp(TypeError,\n                                   r'predictions\\[sparse\\] must be Tensor'):\n        model_fn.EstimatorSpec(mode=ModeKeys.PREDICT, predictions=predictions)\n\n  def testExportOutputsNoDict(self):\n    with tf.Graph().as_default(), self.cached_session():\n      predictions = {'loss': tf.constant(1.)}\n      classes = tf.constant('hello')\n      with self.assertRaisesRegexp(TypeError,\n                                   '[`]*export_outputs[`]* must be dict'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.PREDICT,\n            predictions=predictions,\n            export_outputs=export_output.ClassificationOutput(classes=classes))\n\n  def testExportOutputsValueNotExportOutput(self):\n    with tf.Graph().as_default(), self.cached_session():\n      predictions = {'loss': tf.constant(1.)}\n      with self.assertRaisesRegexp(\n          TypeError,\n          r'Values in [`]*export_outputs[`]* must be ExportOutput objects.'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.PREDICT,\n            predictions=predictions,\n            export_outputs={'head_name': 
predictions})\n\n  def testExportOutputsSingleheadMissingDefault(self):\n    with tf.Graph().as_default(), self.cached_session():\n      predictions = {'loss': tf.constant(1.)}\n      output_1 = tf.constant([1.])\n      regression_output = export_output.RegressionOutput(value=output_1)\n      export_outputs = {\n          'head-1': regression_output,\n      }\n      estimator_spec = model_fn.EstimatorSpec(\n          mode=ModeKeys.PREDICT,\n          predictions=predictions,\n          export_outputs=export_outputs)\n      expected_export_outputs = {\n          tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: regression_output,\n          'head-1': regression_output,\n      }\n      self.assertEqual(expected_export_outputs, estimator_spec.export_outputs)\n\n  def testExportOutputsMultiheadWithDefault(self):\n    with tf.Graph().as_default(), self.cached_session():\n      predictions = {'loss': tf.constant(1.)}\n      output_1 = tf.constant([1.])\n      output_2 = tf.constant(['2'])\n      output_3 = tf.constant(['3'])\n      export_outputs = {\n          tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY:\n              export_output.RegressionOutput(value=output_1),\n          'head-2':\n              export_output.ClassificationOutput(classes=output_2),\n          'head-3':\n              export_output.PredictOutput(outputs={'some_output_3': output_3})\n      }\n      estimator_spec = model_fn.EstimatorSpec(\n          mode=ModeKeys.PREDICT,\n          predictions=predictions,\n          export_outputs=export_outputs)\n      self.assertEqual(export_outputs, estimator_spec.export_outputs)\n\n  def testExportOutputsMultiheadMissingDefault(self):\n    with tf.Graph().as_default(), self.cached_session():\n      predictions = {'loss': tf.constant(1.)}\n      output_1 = tf.constant([1.])\n      output_2 = tf.constant(['2'])\n      output_3 = tf.constant(['3'])\n      export_outputs = {\n          'head-1':\n              export_output.RegressionOutput(value=output_1),\n 
         'head-2':\n              export_output.ClassificationOutput(classes=output_2),\n          'head-3':\n              export_output.PredictOutput(outputs={'some_output_3': output_3})\n      }\n      with self.assertRaisesRegexp(\n          ValueError, 'Multiple [`]*export_outputs[`]* were provided'):\n        model_fn.EstimatorSpec(\n            mode=ModeKeys.PREDICT,\n            predictions=predictions,\n            export_outputs=export_outputs)\n\n  def testDefaultExportOutputCreated(self):\n    \"\"\"Ensure that a default PredictOutput is created for export.\"\"\"\n    with tf.Graph().as_default(), self.cached_session():\n      predictions = tf.constant(1.)\n      self._assertDefaultExportOutputForPredictions(predictions)\n\n  def testDefaultExportOutputCreatedDict(self):\n    \"\"\"Ensure that a default PredictOutput is created for export for dicts.\"\"\"\n    with tf.Graph().as_default(), self.cached_session():\n      predictions = {'loss': tf.constant(1.), 'score': tf.constant(10.)}\n      self._assertDefaultExportOutputForPredictions(predictions)\n\n  def _assertDefaultExportOutputForPredictions(self, predictions):\n    spec = model_fn.EstimatorSpec(\n        mode=ModeKeys.PREDICT, predictions=predictions)\n\n    expected = export_output.PredictOutput(predictions).outputs\n    serving_output = spec.export_outputs[\n        tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]\n    self.assertEqual(serving_output.outputs, expected)\n\n\nclass LogitFnTest(tf.test.TestCase):\n\n  def test_simple_call_logit_fn(self):\n\n    def dummy_logit_fn(features, mode):\n      if mode == ModeKeys.TRAIN:\n        return features['f1']\n      else:\n        return features['f2']\n\n    features = {'f1': tf.constant([[2., 3.]]), 'f2': tf.constant([[4., 5.]])}\n    logit_fn_result = model_fn.call_logit_fn(dummy_logit_fn, features,\n                                             ModeKeys.EVAL, 'fake_params',\n                                             'fake_config')\n    
with self.cached_session():\n      self.assertAllClose([[4., 5.]], self.evaluate(logit_fn_result))\n\n  def test_simple_call_multi_logit_fn(self):\n\n    def dummy_logit_fn(features):\n      return {u'head1': features['f1'], 'head2': features['f2']}\n\n    features = {'f1': tf.constant([[2., 3.]]), 'f2': tf.constant([[4., 5.]])}\n    logit_fn_result = model_fn.call_logit_fn(dummy_logit_fn, features,\n                                             ModeKeys.TRAIN, 'fake_params',\n                                             'fake_config')\n    with self.cached_session():\n      self.assertAllClose([[2., 3.]], self.evaluate(logit_fn_result['head1']))\n      self.assertAllClose([[4., 5.]], self.evaluate(logit_fn_result['head2']))\n\n  def test_invalid_logit_fn_results(self):\n\n    def invalid_logit_fn(features, params):\n      return [\n          features['f1'] * params['input_multiplier'],\n          features['f2'] * params['input_multiplier']\n      ]\n\n    features = {'f1': tf.constant([[2., 3.]]), 'f2': tf.constant([[4., 5.]])}\n    params = {'learning_rate': 0.001, 'input_multiplier': 2.0}\n    with self.assertRaisesRegexp(\n        ValueError, 'logit_fn should return a Tensor or a dictionary mapping '\n        'strings to Tensors'):\n      model_fn.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params,\n                             'fake_config')\n\n  def test_invalid_logit_fn_results_dict(self):\n\n    def invalid_logit_fn(features):\n      return {'head1': features['f1'], 'head2': features['f2']}\n\n    features = {'f1': tf.constant([[2., 3.]]), 'f2': 'some string'}\n    with self.assertRaisesRegexp(\n        ValueError, 'logit_fn should return a Tensor or a dictionary mapping '\n        'strings to Tensors'):\n      model_fn.call_logit_fn(invalid_logit_fn, features, 'fake_mode',\n                             'fake_params', 'fake_config')\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/object_checkpointing_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Integration tests for Estimator + object checkpointing.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\n\nimport os\n# pylint: disable=g-import-not-at-top\ntry:\n  from tensorflow.python.checkpoint import checkpoint as util\nexcept ImportError:\n  # TODO(allenl): Remove this after cl/229814711 syncs\n  from tensorflow.python.training.checkpointable import util\n\nfrom tensorflow_estimator.python.estimator import estimator as estimator_lib\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.export import export_lib\n\n\nclass SubclassedModel(tf_keras.models.Model):\n\n  def __init__(self):\n    super(SubclassedModel, self).__init__()\n    self.dense_one = tf_keras.layers.Dense(5)\n    self.dense_two = tf_keras.layers.Dense(1)\n\n  def call(self, inputs):\n    return self.dense_two(self.dense_one(inputs))\n\n\ndef _serving_input_receiver_fn():\n  receiver = tf.compat.v1.placeholder(\n      tf.dtypes.float32, shape=[None, 1], name='input')\n  return export_lib.ServingInputReceiver(\n      features={'feature': receiver}, 
receiver_tensors=receiver)\n\n\nclass ObjectCheckpointingTest(tf.test.TestCase):\n\n  def _make_estimator(self, model_dir):\n\n    def _model_fn(features, labels, mode):\n      del labels\n      model = SubclassedModel()\n      optimizer = tf_keras.optimizers.Adam(0.01)\n      checkpoint = util.Checkpoint(\n          step=tf.compat.v1.train.get_or_create_global_step(),\n          optimizer=optimizer,\n          model=model)\n      # Make the save counter to satisfy the assert_consumed() assertion later\n      checkpoint.save_counter  # pylint: disable=pointless-statement\n      with tf.GradientTape() as tape:\n        output = model(features['feature'])\n        loss = tf.math.reduce_sum(output)\n      variables = model.trainable_variables\n      gradients = tape.gradient(loss, variables)\n      train_op = tf.group(\n          optimizer.apply_gradients(zip(gradients, variables)),\n          checkpoint.step.assign_add(1))\n      return model_fn_lib.EstimatorSpec(\n          mode,\n          loss=loss,\n          train_op=train_op,\n          predictions=dict(\n              output=output,\n              bias=tf.tile(model.dense_two.bias[None, :],\n                           [tf.compat.v1.shape(output)[0], 1]),\n              step=tf.tile(checkpoint.step[None],\n                           [tf.compat.v1.shape(output)[0]])),\n          scaffold=tf.compat.v1.train.Scaffold(saver=checkpoint))\n\n    est = estimator_lib.EstimatorV2(model_fn=_model_fn, model_dir=model_dir)\n\n    def _input_map_fn(tensor):\n      \"\"\"Converts a tensor into `features, labels` format used by Estimator.\"\"\"\n      return {'feature': tensor}, tensor\n\n    def _input_fn():\n      return tf.compat.v1.data.Dataset.from_tensors(\n          [1.]).repeat().batch(10).map(_input_map_fn)\n\n    return est, _input_fn\n\n  def testTwoWayCompatibility(self):\n    save_model_dir = os.path.join(self.get_temp_dir(), 'model_dir')\n    save_est, input_fn = self._make_estimator(save_model_dir)\n\n    
save_est.train(input_fn, steps=3)\n\n    model = SubclassedModel()\n    optimizer = tf_keras.optimizers.Adam(0.01)\n    checkpoint = util.Checkpoint(\n        step=tf.Variable(0, dtype=tf.dtypes.int64),\n        optimizer=optimizer,\n        model=model)\n    status = checkpoint.restore(tf.train.latest_checkpoint(save_model_dir))\n    self.assertEqual(3, self.evaluate(checkpoint.step))\n    with tf.GradientTape() as tape:\n      output = model(tf.constant([[1.]]))\n      loss = tf.math.reduce_sum(output)\n    variables = model.trainable_variables\n    gradients = tape.gradient(loss, variables)\n    optimizer.apply_gradients(zip(gradients, variables))\n    status.assert_consumed()\n\n    # The optimizer uses this for some reason...\n    tf_keras.backend.clear_session()\n\n    load_model_dir = os.path.join(self.get_temp_dir(), 'load_model_dir/')\n    checkpoint.step.assign(40)\n    checkpoint.model.dense_two.bias.assign([13.])\n    checkpoint.save(load_model_dir)\n    load_est, input_fn = self._make_estimator(load_model_dir)\n    predictions = load_est.predict(input_fn)\n    predictions = next(predictions)\n    self.assertAllClose([13.], predictions['bias'])\n    self.assertEqual(40, predictions['step'])\n\n  def testSavedModelExport(self):\n    model_dir = os.path.join(self.get_temp_dir(), 'estimator_train_dir')\n    estimator, input_fn = self._make_estimator(model_dir)\n    estimator.train(input_fn, steps=1)  # Train to generate a checkpoint.\n\n    export_dir_base = os.path.join(self.get_temp_dir(), 'estimator_export_dir')\n    export_dir = estimator.export_saved_model(export_dir_base,\n                                              _serving_input_receiver_fn)\n\n    # Check the saved model loads and simple inference runs.\n    model = tf.compat.v2.saved_model.load(export_dir)\n    model.signatures['serving_default'](tf.constant([[1.]]))\n\n\nif __name__ == '__main__':\n  tf.compat.v1.enable_eager_execution()\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/run_config.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Environment configuration object for Estimators.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport copy\nimport json\nimport os\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.core.protobuf import rewriter_config_pb2\nfrom tensorflow.python.distribute import estimator_training as distribute_coordinator_training\nfrom tensorflow.python.util import function_utils\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\n\n\n_USE_DEFAULT = object()\n_VALID_DEVICE_FN_ARGS = set(['op'])\n\n# A list of the property names in RunConfig that the user is allowed to change.\n_DEFAULT_REPLACEABLE_LIST = [\n    'model_dir', 'tf_random_seed', 'save_summary_steps',\n    'save_checkpoints_steps', 'save_checkpoints_secs', 'session_config',\n    'keep_checkpoint_max', 'keep_checkpoint_every_n_hours',\n    'log_step_count_steps', 'train_distribute', 'device_fn', 'protocol',\n    'eval_distribute', 'experimental_distribute',\n    'experimental_max_worker_delay_secs', 'session_creation_timeout_secs',\n    'checkpoint_save_graph_def'\n]\n\n_SAVE_CKPT_ERR = (\n    '`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set.')\n\n_TF_CONFIG_ENV = 
'TF_CONFIG'\n_TASK_ENV_KEY = 'task'\n_TASK_TYPE_KEY = 'type'\n_TASK_ID_KEY = 'index'\n_CLUSTER_KEY = 'cluster'\n_SERVICE_KEY = 'service'\n_SESSION_MASTER_KEY = 'session_master'\n_EVAL_SESSION_MASTER_KEY = 'eval_session_master'\n_MODEL_DIR_KEY = 'model_dir'\n_LOCAL_MASTER = ''\n_GRPC_SCHEME = 'grpc://'\n\n\ndef _get_session_master(cluster_spec, task_type, task_id, tf_config):\n  \"\"\"Returns the appropriate address for TensorFlow master.\n\n  The order of precedence to determine the TF session master is as follows:\n  1. If `tf_session_master` is set in TF_CONFIG environment variable, takes it.\n  2. If the cluster has only one node, returns empty string ''.\n  3. Returns the grpc address according to the task type and id in the cluster.\n     This is between-graph replication.\n\n  Note: task_type and task_id must be validated. Typically, validated using\n  `_validate_task_type_and_task_id`.\n\n  Args:\n    cluster_spec: A `ClusterSpec` instance.\n    task_type: String. Task type for current node.\n    task_id: Int. Task id for current node.\n    tf_config: Dict. Python dict for the TF_CONFIG environment variable.\n\n  Raises:\n    RuntimeError: If `cluster_spec` is not set.\n\n  \"\"\"\n  if _SESSION_MASTER_KEY in tf_config:\n    return tf_config[_SESSION_MASTER_KEY]\n\n  if not cluster_spec:\n    raise RuntimeError('Internal error: `_get_session_master` '\n                       'does not expect empty cluster_spec.')\n\n  jobs = cluster_spec.jobs\n\n  # If there is only one node in the cluster, do things locally by setting\n  # master to ''.  
If a service or user sets TF_CONFIG with a single node, it's\n  # more performant to use a direct master rather than an RPC service.\n  if len(jobs) == 1 and len(cluster_spec.job_tasks(jobs[0])) == 1:\n    return _LOCAL_MASTER\n\n  # Lookup the master in cluster_spec using task_type and task_id,\n  # if possible.\n  addresses = cluster_spec.job_tasks(task_type)\n  return _GRPC_SCHEME + addresses[task_id]\n\n\ndef _get_eval_session_master(task_type, tf_config):\n  \"\"\"Returns the appropriate address for TensorFlow evaluation master.\"\"\"\n  if task_type == TaskType.EVALUATOR:\n    return tf_config.get(_EVAL_SESSION_MASTER_KEY, _LOCAL_MASTER)\n\n  return _LOCAL_MASTER\n\n\ndef _count_ps(cluster_spec):\n  \"\"\"Counts the number of parameter servers in cluster_spec.\"\"\"\n  if not cluster_spec:\n    raise RuntimeError(\n        'Internal error: `_count_ps` does not expect empty cluster_spec.')\n\n  return len(cluster_spec.as_dict().get(TaskType.PS, []))\n\n\ndef _count_worker(cluster_spec, chief_task_type):\n  \"\"\"Counts the number of workers (including chief) in cluster_spec.\"\"\"\n  if not cluster_spec:\n    raise RuntimeError(\n        'Internal error: `_count_worker` does not expect empty cluster_spec.')\n\n  return (len(cluster_spec.as_dict().get(TaskType.WORKER, [])) +\n          len(cluster_spec.as_dict().get(chief_task_type, [])))\n\n\ndef _validate_service(service):\n  \"\"\"Validates the service key.\"\"\"\n  if service is not None and not isinstance(service, dict):\n    raise TypeError(\n        'If \"service\" is set in TF_CONFIG, it must be a dict. Given %s' %\n        type(service))\n  return service\n\n\ndef _validate_task_type_and_task_id(cluster_spec, task_env, chief_task_type):\n  \"\"\"Validates the task type and index in `task_env` according to cluster.\"\"\"\n  if chief_task_type not in cluster_spec.jobs:\n    raise ValueError(\n        'If \"cluster\" is set in TF_CONFIG, it must have one \"%s\" node.' 
%\n        chief_task_type)\n  if len(cluster_spec.job_tasks(chief_task_type)) > 1:\n    raise ValueError(\n        'The \"cluster\" in TF_CONFIG must have only one \"%s\" node.' %\n        chief_task_type)\n\n  task_type = task_env.get(_TASK_TYPE_KEY, None)\n  task_id = task_env.get(_TASK_ID_KEY, None)\n\n  if not task_type:\n    raise ValueError('If \"cluster\" is set in TF_CONFIG, task type must be set.')\n  if task_id is None:\n    raise ValueError(\n        'If \"cluster\" is set in TF_CONFIG, task index must be set.')\n\n  task_id = int(task_id)\n\n  # Check the task id bounds. Upper bound is not necessary as\n  # - for evaluator, there is no upper bound.\n  # - for non-evaluator, task id is upper bounded by the number of jobs in\n  # cluster spec, which will be checked later (when retrieving the `master`)\n  if task_id < 0:\n    raise ValueError('Task index must be non-negative number.')\n\n  # Evaluator is not part of the training cluster.\n  if task_type == TaskType.EVALUATOR:\n    return task_type, task_id\n\n  if task_type not in cluster_spec.jobs:\n    raise ValueError(\n        '%s is not a valid task_type in the cluster_spec:\\n'\n        '%s\\n\\n'\n        'Note that these values may be coming from the TF_CONFIG environment '\n        'variable.' % (task_type, cluster_spec))\n  addresses = cluster_spec.job_tasks(task_type)\n  if not 0 <= task_id < len(addresses):\n    raise ValueError(\n        '%d is not a valid task_id for task_type %s in the cluster_spec:\\n'\n        '%s\\n\\n'\n        'Note that these values may be coming from the TF_CONFIG environment '\n        'variable.' % (task_id, task_type, cluster_spec))\n\n  return task_type, task_id\n\n\ndef _get_global_id_in_cluster(cluster_spec, task_type, task_id,\n                              chief_task_type):\n  \"\"\"Returns the global id in cluster.\"\"\"\n  # Note: This is implementation details, which user should not rely on.\n  # The first id is 0, which is always for the `chief` node. 
All other nodes,\n  # except `ps`, are ordered alphabetical based on task type (alphabetically)\n  # and task id (ascendingly). `ps` are ordered last.\n\n  # Sort task names in cluster\n  task_type_ordered_list = [chief_task_type]\n  task_type_ordered_list.extend([\n      t for t in sorted(cluster_spec.jobs)\n      if t != chief_task_type and t != TaskType.PS\n  ])\n  if TaskType.PS in cluster_spec.jobs:\n    task_type_ordered_list.append(TaskType.PS)\n\n  next_global_id = 0\n  for t in task_type_ordered_list:\n    if t == task_type:\n      return next_global_id + task_id\n    next_global_id += len(cluster_spec.job_tasks(t))\n\n  # This should never happen.\n  raise RuntimeError('Internal Error: `task_type` ({}) is not in '\n                     'cluster_spec ({}).'.format(task_type, cluster_spec))\n\n\ndef _validate_save_ckpt_with_replaced_keys(new_copy, replaced_keys):\n  \"\"\"Validates the save ckpt properties.\"\"\"\n  # Ensure one (and only one) of save_steps and save_secs is not None.\n  # Also, if user sets one save ckpt property, say steps, the other one (secs)\n  # should be set as None to improve usability.\n\n  save_steps = new_copy.save_checkpoints_steps\n  save_secs = new_copy.save_checkpoints_secs\n\n  if ('save_checkpoints_steps' in replaced_keys and\n      'save_checkpoints_secs' in replaced_keys):\n    # If user sets both properties explicitly, we need to error out if both\n    # are set or neither of them are set.\n    if save_steps is not None and save_secs is not None:\n      raise ValueError(_SAVE_CKPT_ERR)\n  elif 'save_checkpoints_steps' in replaced_keys and save_steps is not None:\n    new_copy._save_checkpoints_secs = None  # pylint: disable=protected-access\n  elif 'save_checkpoints_secs' in replaced_keys and save_secs is not None:\n    new_copy._save_checkpoints_steps = None  # pylint: disable=protected-access\n\n\ndef _validate_properties(run_config):\n  \"\"\"Validates the properties.\"\"\"\n\n  def _validate(property_name, cond, 
message):\n    property_value = getattr(run_config, property_name)\n    if property_value is not None and not cond(property_value):\n      raise ValueError(message)\n\n  def _validate_delay(delay):\n    \"\"\"Check that delay is an integer value.\n\n    Since this has to work for both Python2 and Python3 and PEP237 defines long\n    to be basically int, we cannot just use a lambda function.\n    \"\"\"\n    try:\n      return isinstance(delay, (int, long))\n    except NameError:\n      # PEP237 redefines long to int for Python3\n      return isinstance(delay, int)\n\n  _validate(\n      'model_dir', lambda dir: dir, message='model_dir should be non-empty')\n\n  _validate(\n      'save_summary_steps',\n      lambda steps: steps >= 0,\n      message='save_summary_steps should be >= 0')\n\n  _validate(\n      'save_checkpoints_steps',\n      lambda steps: steps >= 0,\n      message='save_checkpoints_steps should be >= 0')\n  _validate(\n      'save_checkpoints_secs',\n      lambda secs: secs >= 0,\n      message='save_checkpoints_secs should be >= 0')\n\n  _validate(\n      'session_config',\n      lambda sc: isinstance(sc, tf.compat.v1.ConfigProto),\n      message='session_config must be instance of ConfigProto')\n\n  _validate(\n      'keep_checkpoint_max',\n      lambda keep_max: keep_max >= 0,\n      message='keep_checkpoint_max should be >= 0')\n  _validate(\n      'keep_checkpoint_every_n_hours',\n      lambda keep_hours: keep_hours > 0,\n      message='keep_checkpoint_every_n_hours should be > 0')\n  _validate(\n      'log_step_count_steps',\n      lambda num_steps: num_steps > 0,\n      message='log_step_count_steps should be > 0')\n\n  _validate(\n      'tf_random_seed',\n      lambda seed: isinstance(seed, six.integer_types),\n      message='tf_random_seed must be integer.')\n\n  _validate(\n      'experimental_max_worker_delay_secs',\n      _validate_delay,\n      message='experimental_max_worker_delay_secs must be an integer if'\n      ' set.')\n  
_validate(\n      'session_creation_timeout_secs',\n      lambda timeout_secs: timeout_secs > 0,\n      message='session_creation_timeout_secs should be > 0')\n\n  _validate(\n      'device_fn',\n      lambda device_fn: six.callable(device_fn) and set(\n          function_utils.fn_args(device_fn)) == _VALID_DEVICE_FN_ARGS,\n      message='device_fn must be callable with exactly'\n      ' one argument \"op\".')\n\n  _validate(\n      'protocol',\n      lambda protocol: protocol in (None, 'grpc', 'grpc+verbs'),\n      message='protocol should be grpc or grpc+verbs')\n\n\ndef get_default_session_config():\n  \"\"\"Returns tf.ConfigProto instance.\"\"\"\n\n  rewrite_opts = rewriter_config_pb2.RewriterConfig(\n      meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE)\n  graph_opts = tf.compat.v1.GraphOptions(rewrite_options=rewrite_opts)\n\n  return tf.compat.v1.ConfigProto(\n      allow_soft_placement=True, graph_options=graph_opts)\n\n\nclass TaskType(object):\n  MASTER = 'master'\n  PS = 'ps'\n  WORKER = 'worker'\n  CHIEF = 'chief'\n  EVALUATOR = 'evaluator'\n\n\n@estimator_export('estimator.RunConfig')\nclass RunConfig(object):\n  \"\"\"This class specifies the configurations for an `Estimator` run.\"\"\"\n\n  def __init__(self,\n               model_dir=None,\n               tf_random_seed=None,\n               save_summary_steps=100,\n               save_checkpoints_steps=_USE_DEFAULT,\n               save_checkpoints_secs=_USE_DEFAULT,\n               session_config=None,\n               keep_checkpoint_max=5,\n               keep_checkpoint_every_n_hours=10000,\n               log_step_count_steps=100,\n               train_distribute=None,\n               device_fn=None,\n               protocol=None,\n               eval_distribute=None,\n               experimental_distribute=None,\n               experimental_max_worker_delay_secs=None,\n               session_creation_timeout_secs=7200,\n               checkpoint_save_graph_def=True):\n    
\"\"\"Constructs a RunConfig.\n\n    All distributed training related properties `cluster_spec`, `is_chief`,\n    `master` , `num_worker_replicas`, `num_ps_replicas`, `task_id`, and\n    `task_type` are set based on the `TF_CONFIG` environment variable, if the\n    pertinent information is present. The `TF_CONFIG` environment variable is a\n    JSON object with attributes: `cluster` and `task`.\n\n    `cluster` is a JSON serialized version of `ClusterSpec`'s Python dict from\n    `server_lib.py`, mapping task types (usually one of the `TaskType` enums) to\n    a list of task addresses.\n\n    `task` has two attributes: `type` and `index`, where `type` can be any of\n    the task types in `cluster`. When `TF_CONFIG` contains said information,\n    the following properties are set on this class:\n\n    * `cluster_spec` is parsed from `TF_CONFIG['cluster']`. Defaults to {}. If\n      present, must have one and only one node in the `chief` attribute of\n      `cluster_spec`.\n    * `task_type` is set to `TF_CONFIG['task']['type']`. Must set if\n      `cluster_spec` is present; must be `worker` (the default value) if\n      `cluster_spec` is not set.\n    * `task_id` is set to `TF_CONFIG['task']['index']`. Must set if\n      `cluster_spec` is present; must be 0 (the default value) if\n      `cluster_spec` is not set.\n    * `master` is determined by looking up `task_type` and `task_id` in the\n      `cluster_spec`. Defaults to ''.\n    * `num_ps_replicas` is set by counting the number of nodes listed\n      in the `ps` attribute of `cluster_spec`. Defaults to 0.\n    * `num_worker_replicas` is set by counting the number of nodes listed\n      in the `worker` and `chief` attributes of `cluster_spec`. Defaults to 1.\n    * `is_chief` is determined based on `task_type` and `cluster`.\n\n    There is a special node with `task_type` as `evaluator`, which is not part\n    of the (training) `cluster_spec`. 
It handles the distributed evaluation job.\n\n    Example of non-chief node:\n    ```\n      cluster = {'chief': ['host0:2222'],\n                 'ps': ['host1:2222', 'host2:2222'],\n                 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}\n      os.environ['TF_CONFIG'] = json.dumps(\n          {'cluster': cluster,\n           'task': {'type': 'worker', 'index': 1}})\n      config = RunConfig()\n      assert config.master == 'host4:2222'\n      assert config.task_id == 1\n      assert config.num_ps_replicas == 2\n      assert config.num_worker_replicas == 4\n      assert config.cluster_spec == server_lib.ClusterSpec(cluster)\n      assert config.task_type == 'worker'\n      assert not config.is_chief\n    ```\n\n    Example of chief node:\n    ```\n      cluster = {'chief': ['host0:2222'],\n                 'ps': ['host1:2222', 'host2:2222'],\n                 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}\n      os.environ['TF_CONFIG'] = json.dumps(\n          {'cluster': cluster,\n           'task': {'type': 'chief', 'index': 0}})\n      config = RunConfig()\n      assert config.master == 'host0:2222'\n      assert config.task_id == 0\n      assert config.num_ps_replicas == 2\n      assert config.num_worker_replicas == 4\n      assert config.cluster_spec == server_lib.ClusterSpec(cluster)\n      assert config.task_type == 'chief'\n      assert config.is_chief\n    ```\n\n    Example of evaluator node (evaluator is not part of training cluster):\n    ```\n      cluster = {'chief': ['host0:2222'],\n                 'ps': ['host1:2222', 'host2:2222'],\n                 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}\n      os.environ['TF_CONFIG'] = json.dumps(\n          {'cluster': cluster,\n           'task': {'type': 'evaluator', 'index': 0}})\n      config = RunConfig()\n      assert config.master == ''\n      assert config.evaluator_master == ''\n      assert config.task_id == 0\n      assert config.num_ps_replicas == 0\n      assert 
config.num_worker_replicas == 0\n      assert config.cluster_spec == {}\n      assert config.task_type == 'evaluator'\n      assert not config.is_chief\n    ```\n\n    N.B.: If `save_checkpoints_steps` or `save_checkpoints_secs` is set,\n    `keep_checkpoint_max` might need to be adjusted accordingly, especially in\n    distributed training. For example, setting `save_checkpoints_secs` as 60\n    without adjusting `keep_checkpoint_max` (defaults to 5) leads to situation\n    that checkpoint would be garbage collected after 5 minutes. In distributed\n    training, the evaluation job starts asynchronously and might fail to load or\n    find the checkpoint due to race condition.\n\n    Args:\n      model_dir: directory where model parameters, graph, etc are saved. If\n        `PathLike` object, the path will be resolved. If `None`, will use a\n        default value set by the Estimator.\n      tf_random_seed: Random seed for TensorFlow initializers. Setting this\n        value allows consistency between reruns.\n      save_summary_steps: Save summaries every this many steps.\n      save_checkpoints_steps: Save checkpoints every this many steps. Can not be\n        specified with `save_checkpoints_secs`.\n      save_checkpoints_secs: Save checkpoints every this many seconds. Can not\n        be specified with `save_checkpoints_steps`. Defaults to 600 seconds if\n        both `save_checkpoints_steps` and `save_checkpoints_secs` are not set in\n        constructor.  If both `save_checkpoints_steps` and\n        `save_checkpoints_secs` are `None`, then checkpoints are disabled.\n      session_config: a ConfigProto used to set session parameters, or `None`.\n      keep_checkpoint_max: The maximum number of recent checkpoint files to\n        keep. As new files are created, older files are deleted. If `None` or 0,\n        all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent\n        checkpoint files are kept). 
If a saver is passed to the estimator, this\n        argument will be ignored.\n      keep_checkpoint_every_n_hours: Number of hours between each checkpoint to\n        be saved. The default value of 10,000 hours effectively disables the\n        feature.\n      log_step_count_steps: The frequency, in number of global steps, that the\n        global step and the loss will be logged during training.  Also controls\n        the frequency that the global steps / s will be logged (and written to\n        summary) during training.\n      train_distribute: An optional instance of `tf.distribute.Strategy`. If\n        specified, then Estimator will distribute the user's model during\n        training, according to the policy specified by that strategy. Setting\n        `experimental_distribute.train_distribute` is preferred.\n      device_fn: A callable invoked for every `Operation` that takes the\n        `Operation` and returns the device string. If `None`, defaults to the\n        device function returned by `tf.train.replica_device_setter` with\n        round-robin strategy.\n      protocol: An optional argument which specifies the protocol used when\n        starting server. `None` means default to grpc.\n      eval_distribute: An optional instance of `tf.distribute.Strategy`. If\n        specified, then Estimator will distribute the user's model during\n        evaluation, according to the policy specified by that strategy. Setting\n        `experimental_distribute.eval_distribute` is preferred.\n      experimental_distribute: An optional\n        `tf.contrib.distribute.DistributeConfig` object specifying\n        DistributionStrategy-related configuration. The `train_distribute` and\n        `eval_distribute` can be passed as parameters to `RunConfig` or set in\n        `experimental_distribute` but not both.\n      experimental_max_worker_delay_secs: An optional integer specifying the\n        maximum time a worker should wait before starting. 
By default, workers\n        are started at staggered times, with each worker being delayed by up to\n        60 seconds. This is intended to reduce the risk of divergence, which can\n        occur when many workers simultaneously update the weights of a randomly\n        initialized model. Users who warm-start their models and train them for\n        short durations (a few minutes or less) should consider reducing this\n        default to improve training times.\n      session_creation_timeout_secs: Max time workers should wait for a session\n        to become available (on initialization or when recovering a session)\n        with MonitoredTrainingSession. Defaults to 7200 seconds, but users may\n        want to set a lower value to detect problems with variable / session\n        (re)-initialization more quickly.\n      checkpoint_save_graph_def: Whether to save the GraphDef and MetaGraphDef\n        to `checkpoint_dir`. The GraphDef is saved after the session is created\n        as `graph.pbtxt`. 
MetaGraphDefs are saved out for every checkpoint as\n        `model.ckpt-*.meta`.\n\n    Raises:\n      ValueError: If both `save_checkpoints_steps` and `save_checkpoints_secs`\n      are set.\n    \"\"\"\n    if (save_checkpoints_steps == _USE_DEFAULT and\n        save_checkpoints_secs == _USE_DEFAULT):\n      save_checkpoints_steps = None\n      save_checkpoints_secs = 600\n    elif save_checkpoints_secs == _USE_DEFAULT:\n      save_checkpoints_secs = None\n    elif save_checkpoints_steps == _USE_DEFAULT:\n      save_checkpoints_steps = None\n    elif (save_checkpoints_steps is not None and\n          save_checkpoints_secs is not None):\n      raise ValueError(_SAVE_CKPT_ERR)\n\n    self._verify_strategy_compatibility(train_distribute, eval_distribute)\n\n    tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))\n    if tf_config:\n      tf.compat.v1.logging.info('TF_CONFIG environment variable: %s', tf_config)\n\n    model_dir = _get_model_dir(tf_config, path_to_str(model_dir))\n\n    RunConfig._replace(\n        self,\n        allowed_properties_list=_DEFAULT_REPLACEABLE_LIST,\n        model_dir=model_dir,\n        tf_random_seed=tf_random_seed,\n        save_summary_steps=save_summary_steps,\n        save_checkpoints_steps=save_checkpoints_steps,\n        save_checkpoints_secs=save_checkpoints_secs,\n        session_config=session_config,\n        keep_checkpoint_max=keep_checkpoint_max,\n        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,\n        log_step_count_steps=log_step_count_steps,\n        train_distribute=train_distribute,\n        device_fn=device_fn,\n        protocol=protocol,\n        eval_distribute=eval_distribute,\n        experimental_distribute=experimental_distribute,\n        experimental_max_worker_delay_secs=experimental_max_worker_delay_secs,\n        session_creation_timeout_secs=session_creation_timeout_secs,\n        checkpoint_save_graph_def=checkpoint_save_graph_def)\n\n    # TODO(frankchn,priyag): 
Eventually use distributed coordinator for TPUs.\n    if ((train_distribute and\n         not train_distribute.__class__.__name__.startswith('TPUStrategy')) or\n        (eval_distribute and\n         not eval_distribute.__class__.__name__.startswith('TPUStrategy')) or\n        experimental_distribute):\n      tf.compat.v1.logging.info(\n          'Initializing RunConfig with distribution strategies.')\n      distribute_coordinator_training.init_run_config(self, tf_config)\n    else:\n      self._init_distributed_setting_from_environment_var(tf_config)\n      self._maybe_overwrite_session_config_for_distributed_training()\n\n  def _verify_strategy_compatibility(self, train_distribute, eval_distribute):\n    if ((train_distribute is not None and train_distribute.__class__ ==\n         tf.compat.v2.distribute.experimental.ParameterServerStrategy) or\n        (eval_distribute is not None and eval_distribute.__class__ ==\n         tf.compat.v2.distribute.experimental.ParameterServerStrategy)):\n      raise ValueError('Please use `tf.compat.v1.distribute.experimental.Param'\n                       'eterServerStrategy` for parameter server strategy with '\n                       'estimator.')\n\n  def _maybe_overwrite_session_config_for_distributed_training(self):\n    \"\"\"Overwrites the session_config for distributed training.\n\n    The default overwrite is optimized for between-graph training. 
Subclass\n    should override this method if necessary.\n    \"\"\"\n    # Get session_config only for between-graph distributed mode (cluster_spec\n    # is present).\n    if not self._session_config and self._cluster_spec:\n      RunConfig._replace(\n          self,\n          allowed_properties_list=_DEFAULT_REPLACEABLE_LIST,\n          session_config=self._get_default_session_config_distributed())\n\n  def _get_default_session_config_distributed(self):\n    \"\"\"Returns None or tf.ConfigProto instance with default device_filters set.\n\n    Device filters are set such that chief/master and worker communicates with\n    only ps. session_config=None for evaluators or any other TaskType.\n    \"\"\"\n\n    rewrite_opts = rewriter_config_pb2.RewriterConfig(\n        meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE)\n    graph_opts = tf.compat.v1.GraphOptions(rewrite_options=rewrite_opts)\n\n    device_filters = None\n    if self._task_type == TaskType.MASTER:\n      device_filters = ['/job:ps', '/job:master']\n    elif self._task_type == TaskType.CHIEF:\n      device_filters = ['/job:ps', '/job:chief']\n    elif self._task_type == TaskType.WORKER:\n      device_filters = ['/job:ps', '/job:worker/task:%d' % self._task_id]\n    elif self._task_type == TaskType.PS:\n      device_filters = ['/job:ps', '/job:worker', '/job:chief', '/job:master']\n    else:\n      # If the task_type is `EVALUATOR` or something other than the ones in\n      # TaskType then don't set any device filters.\n      return None\n\n    return tf.compat.v1.ConfigProto(\n        allow_soft_placement=True,\n        graph_options=graph_opts,\n        device_filters=device_filters)\n\n  def _init_distributed_setting_from_environment_var(self, tf_config):\n    \"\"\"Initialize distributed properties based on `tf_config`.\"\"\"\n\n    self._service = _validate_service(tf_config.get(_SERVICE_KEY))\n    self._cluster_spec = tf.train.ClusterSpec(tf_config.get(_CLUSTER_KEY, {}))\n    
task_env = tf_config.get(_TASK_ENV_KEY, {})\n\n    if self._cluster_spec and TaskType.MASTER in self._cluster_spec.jobs:\n      return self._init_distributed_setting_from_environment_var_with_master(\n          tf_config)\n\n    if self._cluster_spec:\n      # Distributed mode.\n      self._task_type, self._task_id = _validate_task_type_and_task_id(\n          self._cluster_spec, task_env, TaskType.CHIEF)\n\n      self._evaluation_master = _get_eval_session_master(\n          self._task_type, tf_config)\n\n      if self._task_type != TaskType.EVALUATOR:\n        self._master = _get_session_master(self._cluster_spec, self._task_type,\n                                           self._task_id, tf_config)\n        self._num_ps_replicas = _count_ps(self._cluster_spec)\n        self._num_worker_replicas = _count_worker(\n            self._cluster_spec, chief_task_type=TaskType.CHIEF)\n        self._global_id_in_cluster = _get_global_id_in_cluster(\n            self._cluster_spec,\n            self._task_type,\n            self._task_id,\n            chief_task_type=TaskType.CHIEF)\n      else:\n        # Evaluator is not part of the training cluster.\n        self._cluster_spec = tf.train.ClusterSpec({})\n        self._master = _LOCAL_MASTER\n        self._num_ps_replicas = 0\n        self._num_worker_replicas = 0\n        self._global_id_in_cluster = None  # undefined\n\n      self._is_chief = self._task_type == TaskType.CHIEF\n    else:\n      # Local mode.\n      self._task_type = task_env.get(_TASK_TYPE_KEY, TaskType.WORKER)\n      self._task_id = int(task_env.get(_TASK_ID_KEY, 0))\n      self._global_id_in_cluster = 0\n\n      if self._task_type != TaskType.WORKER:\n        raise ValueError(\n            'If \"cluster\" is not set in TF_CONFIG, task type must be WORKER.')\n      if self._task_id != 0:\n        raise ValueError(\n            'If \"cluster\" is not set in TF_CONFIG, task index must be 0.')\n\n      self._master = tf_config.get(_SESSION_MASTER_KEY, 
_LOCAL_MASTER)\n      self._evaluation_master = tf_config.get(_EVAL_SESSION_MASTER_KEY,\n                                              _LOCAL_MASTER)\n      self._is_chief = True\n      self._num_ps_replicas = 0\n      self._num_worker_replicas = 1\n\n  def _init_distributed_setting_from_environment_var_with_master(\n      self, tf_config):\n    \"\"\"Initialize distributed properties for legacy cluster with `master`.\"\"\"\n    # There is no tech reason, why user cannot have chief and master in the same\n    # cluster, but it is super confusing (which is really the chief?). So, block\n    # this case.\n    if TaskType.CHIEF in self._cluster_spec.jobs:\n      raise ValueError('If `master` node exists in `cluster`, job '\n                       '`chief` is not supported.')\n\n    task_env = tf_config.get(_TASK_ENV_KEY, {})\n\n    self._task_type, self._task_id = _validate_task_type_and_task_id(\n        self._cluster_spec, task_env, TaskType.MASTER)\n\n    if self._task_type == TaskType.EVALUATOR:\n      raise ValueError('If `master` node exists in `cluster`, task_type '\n                       '`evaluator` is not supported.')\n\n    self._global_id_in_cluster = _get_global_id_in_cluster(\n        self._cluster_spec,\n        self._task_type,\n        self._task_id,\n        chief_task_type=TaskType.MASTER)\n\n    self._master = _get_session_master(self._cluster_spec, self._task_type,\n                                       self._task_id, tf_config)\n    self._evaluation_master = _get_eval_session_master(self._task_type,\n                                                       tf_config)\n    self._num_ps_replicas = _count_ps(self._cluster_spec)\n    self._num_worker_replicas = _count_worker(\n        self._cluster_spec, chief_task_type=TaskType.MASTER)\n\n    self._is_chief = self._task_type == TaskType.MASTER\n\n  @property\n  def cluster_spec(self):\n    return self._cluster_spec\n\n  @property\n  def device_fn(self):\n    \"\"\"Returns the device_fn.\n\n    If 
device_fn is not `None`, it overrides the default\n    device function used in `Estimator`.\n    Otherwise the default one is used.\n    \"\"\"\n    return self._device_fn\n\n  @property\n  def evaluation_master(self):\n    return self._evaluation_master\n\n  @property\n  def is_chief(self):\n    return self._is_chief\n\n  @property\n  def master(self):\n    return self._master\n\n  @property\n  def num_ps_replicas(self):\n    return self._num_ps_replicas\n\n  @property\n  def num_worker_replicas(self):\n    return self._num_worker_replicas\n\n  @property\n  def task_id(self):\n    return self._task_id\n\n  @property\n  def global_id_in_cluster(self):\n    \"\"\"The global id in the training cluster.\n\n    All global ids in the training cluster are assigned from an increasing\n    sequence of consecutive integers. The first id is 0.\n\n    Note: Task id (the property field `task_id`) is tracking the index of the\n    node among all nodes with the SAME task type. For example, given the cluster\n    definition as follows:\n\n    ```\n      cluster = {'chief': ['host0:2222'],\n                 'ps': ['host1:2222', 'host2:2222'],\n                 'worker': ['host3:2222', 'host4:2222', 'host5:2222']}\n    ```\n\n    Nodes with task type `worker` can have id 0, 1, 2.  Nodes with task type\n    `ps` can have id, 0, 1. So, `task_id` is not unique, but the pair\n    (`task_type`, `task_id`) can uniquely determine a node in the cluster.\n\n    Global id, i.e., this field, is tracking the index of the node among ALL\n    nodes in the cluster. It is uniquely assigned.  
For example, for the cluster\n    spec given above, the global ids are assigned as:\n    ```\n      task_type  | task_id  |  global_id\n      --------------------------------\n      chief      | 0        |  0\n      worker     | 0        |  1\n      worker     | 1        |  2\n      worker     | 2        |  3\n      ps         | 0        |  4\n      ps         | 1        |  5\n    ```\n\n    Returns:\n      An integer id.\n    \"\"\"\n    return self._global_id_in_cluster\n\n  @property\n  def experimental_max_worker_delay_secs(self):\n    return self._experimental_max_worker_delay_secs\n\n  @property\n  def task_type(self):\n    return self._task_type\n\n  @property\n  def tf_random_seed(self):\n    return self._tf_random_seed\n\n  @property\n  def save_summary_steps(self):\n    return self._save_summary_steps\n\n  @property\n  def save_checkpoints_secs(self):\n    return self._save_checkpoints_secs\n\n  @property\n  def session_config(self):\n    return self._session_config\n\n  @property\n  def save_checkpoints_steps(self):\n    return self._save_checkpoints_steps\n\n  @property\n  def checkpoint_save_graph_def(self):\n    return self._checkpoint_save_graph_def\n\n  @property\n  def keep_checkpoint_max(self):\n    return self._keep_checkpoint_max\n\n  @property\n  def session_creation_timeout_secs(self):\n    return self._session_creation_timeout_secs\n\n  @property\n  def keep_checkpoint_every_n_hours(self):\n    return self._keep_checkpoint_every_n_hours\n\n  @property\n  def log_step_count_steps(self):\n    return self._log_step_count_steps\n\n  @property\n  def model_dir(self):\n    return self._model_dir\n\n  @property\n  def service(self):\n    \"\"\"Returns the platform defined (in TF_CONFIG) service dict.\"\"\"\n    return self._service\n\n  @property\n  def train_distribute(self):\n    \"\"\"Optional `tf.distribute.Strategy` for training.\"\"\"\n    return self._train_distribute\n\n  @property\n  def eval_distribute(self):\n    \"\"\"Optional 
`tf.distribute.Strategy` for evaluation.\"\"\"\n    return self._eval_distribute\n\n  @property\n  def protocol(self):\n    \"\"\"Returns the optional protocol value.\"\"\"\n    return self._protocol\n\n  def replace(self, **kwargs):\n    \"\"\"Returns a new instance of `RunConfig` replacing specified properties.\n\n    Only the properties in the following list are allowed to be replaced:\n\n      - `model_dir`,\n      - `tf_random_seed`,\n      - `save_summary_steps`,\n      - `save_checkpoints_steps`,\n      - `save_checkpoints_secs`,\n      - `session_config`,\n      - `keep_checkpoint_max`,\n      - `keep_checkpoint_every_n_hours`,\n      - `log_step_count_steps`,\n      - `train_distribute`,\n      - `device_fn`,\n      - `protocol`.\n      - `eval_distribute`,\n      - `experimental_distribute`,\n      - `experimental_max_worker_delay_secs`,\n\n    In addition, either `save_checkpoints_steps` or `save_checkpoints_secs`\n    can be set (should not be both).\n\n    Args:\n      **kwargs: keyword named properties with new values.\n\n    Raises:\n      ValueError: If any property name in `kwargs` does not exist or is not\n        allowed to be replaced, or both `save_checkpoints_steps` and\n        `save_checkpoints_secs` are set.\n\n    Returns:\n      a new instance of `RunConfig`.\n    \"\"\"\n    return RunConfig._replace(\n        copy.deepcopy(self),\n        allowed_properties_list=_DEFAULT_REPLACEABLE_LIST,\n        **kwargs)\n\n  @staticmethod\n  def _replace(config, allowed_properties_list=None, **kwargs):\n    \"\"\"See `replace`.\n\n    N.B.: This implementation assumes that for key named \"foo\", the underlying\n    property the RunConfig holds is \"_foo\" (with one leading underscore).\n\n    Args:\n      config: The RunConfig to replace the values of.\n      allowed_properties_list: The property name list allowed to be replaced.\n      **kwargs: keyword named properties with new values.\n\n    Raises:\n      ValueError: If any property name in 
`kwargs` does not exist or is not\n        allowed to be replaced, or both `save_checkpoints_steps` and\n        `save_checkpoints_secs` are set.\n\n    Returns:\n      a new instance of `RunConfig`.\n    \"\"\"\n\n    allowed_properties_list = allowed_properties_list or []\n\n    for key, new_value in six.iteritems(kwargs):\n      if key in allowed_properties_list:\n        setattr(config, '_' + key, new_value)\n        continue\n\n      raise ValueError(\n          'Replacing {} is not supported. Allowed properties are {}.'.format(\n              key, allowed_properties_list))\n\n    _validate_save_ckpt_with_replaced_keys(config, kwargs.keys())\n    _validate_properties(config)\n    return config\n\n\ndef _get_model_dir(tf_config, model_dir):\n  \"\"\"Returns `model_dir` based user provided `tf_config` or `model_dir`.\"\"\"\n  # pylint: disable=g-explicit-bool-comparison\n\n  # Empty string is treated as False in Python condition check, which triggers\n  # some confusing error messages. For example, 'a or b' returns None if a is ''\n  # and b is None. `None` is allowed for model_dir but '' is not allowed. Here,\n  # explicitly check empty string to provide clear error message.\n  if model_dir == '':\n    raise ValueError('model_dir should be non-empty.')\n\n  model_dir_in_tf_config = tf_config.get('model_dir')\n  if model_dir_in_tf_config == '':\n    raise ValueError('model_dir in TF_CONFIG should be non-empty.')\n\n  if model_dir_in_tf_config:\n    if model_dir and model_dir_in_tf_config != model_dir:\n      raise ValueError(\n          '`model_dir` provided in RunConfig construct, if set, '\n          'must have the same value as the model_dir in TF_CONFIG. 
'\n          'model_dir: {}\\nTF_CONFIG[\"model_dir\"]: {}.\\n'.format(\n              model_dir, model_dir_in_tf_config))\n\n    tf.compat.v1.logging.info('Using model_dir in TF_CONFIG: %s',\n                              model_dir_in_tf_config)\n\n  return model_dir or model_dir_in_tf_config\n\n\ndef path_to_str(path):\n  \"\"\"Returns the file system path representation of a `PathLike` object, else as it is.\n\n  Args:\n    path: An object that can be converted to path representation.\n\n  Returns:\n    A `str` object.\n  \"\"\"\n  if hasattr(path, '__fspath__'):\n    path = tf.compat.as_str_any(path.__fspath__())\n  return path\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/run_config_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"RunConfig tests.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport json\nimport tensorflow as tf\nfrom tensorflow.core.protobuf import rewriter_config_pb2\nfrom tensorflow_estimator.python.estimator import run_config as run_config_lib\n\n\n_TEST_DIR = 'test_dir'\n_MASTER = 'master_'\n_NOT_SUPPORTED_REPLACE_PROPERTY_MSG = 'Replacing .*is not supported'\n_SAVE_CKPT_ERR = (\n    '`save_checkpoints_steps` and `save_checkpoints_secs` cannot be both set.')\n_MODEL_DIR_ERR = 'model_dir should be non-empty'\n_MODEL_DIR_TF_CONFIG_ERR = 'model_dir in TF_CONFIG should be non-empty'\n_MODEL_DIR_MISMATCH_ERR = (\n    '`model_dir` provided in RunConfig construct, if set, '\n    'must have the same value as the model_dir in TF_CONFIG. 
')\n_SAVE_SUMMARY_STEPS_ERR = 'save_summary_steps should be >= 0'\n_SAVE_CKPT_STEPS_ERR = 'save_checkpoints_steps should be >= 0'\n_SAVE_CKPT_SECS_ERR = 'save_checkpoints_secs should be >= 0'\n_SESSION_CONFIG_ERR = 'session_config must be instance of ConfigProto'\n_KEEP_CKPT_MAX_ERR = 'keep_checkpoint_max should be >= 0'\n_KEEP_CKPT_HOURS_ERR = 'keep_checkpoint_every_n_hours should be > 0'\n_TF_RANDOM_SEED_ERR = 'tf_random_seed must be integer'\n_DEVICE_FN_ERR = 'device_fn must be callable with exactly one argument \"op\".'\n_ONE_CHIEF_ERR = 'The \"cluster\" in TF_CONFIG must have only one \"chief\" node.'\n_ONE_MASTER_ERR = 'The \"cluster\" in TF_CONFIG must have only one \"master\" node.'\n_MISSING_CHIEF_ERR = 'If \"cluster\" is set .* it must have one \"chief\" node'\n_MISSING_TASK_TYPE_ERR = 'If \"cluster\" is set .* task type must be set'\n_MISSING_TASK_ID_ERR = 'If \"cluster\" is set .* task index must be set'\n_INVALID_TASK_INDEX_ERR = 'is not a valid task_id'\n_NEGATIVE_TASK_INDEX_ERR = 'Task index must be non-negative number.'\n_INVALID_TASK_TYPE_ERR = 'is not a valid task_type'\n_INVALID_TASK_TYPE_FOR_LOCAL_ERR = (\n    'If \"cluster\" is not set in TF_CONFIG, task type must be WORKER.')\n_INVALID_TASK_INDEX_FOR_LOCAL_ERR = (\n    'If \"cluster\" is not set in TF_CONFIG, task index must be 0.')\n_INVALID_EVALUATOR_IN_CLUSTER_WITH_MASTER_ERR = (\n    'If `master` node exists in `cluster`, task_type `evaluator` is not '\n    'supported.')\n_INVALID_CHIEF_IN_CLUSTER_WITH_MASTER_ERR = (\n    'If `master` node exists in `cluster`, job `chief` is not supported.')\n_INVALID_SERVICE_TYPE_ERR = (\n    'If \"service\" is set in TF_CONFIG, it must be a dict. 
Given')\n_EXPERIMENTAL_MAX_WORKER_DELAY_SECS_ERR = (\n    'experimental_max_worker_delay_secs must be an integer if set.')\n_SESSION_CREATION_TIMEOUT_SECS_ERR = ('session_creation_timeout_secs should be '\n                                      '> 0')\n\n\ndef _create_run_config_with_cluster_spec(tf_config, **kwargs):\n  with tf.compat.v1.test.mock.patch.dict('os.environ',\n                                         {'TF_CONFIG': json.dumps(tf_config)}):\n    return run_config_lib.RunConfig(**kwargs)\n\n\nclass RunConfigTest(tf.test.TestCase):\n\n  def test_default_property_values(self):\n    config = run_config_lib.RunConfig()\n    self.assertIsNone(config.model_dir)\n    self.assertIsNone(config.session_config)\n    self.assertIsNone(config.tf_random_seed)\n    self.assertEqual(100, config.save_summary_steps)\n    self.assertEqual(600, config.save_checkpoints_secs)\n    self.assertIsNone(config.save_checkpoints_steps)\n    self.assertEqual(5, config.keep_checkpoint_max)\n    self.assertEqual(10000, config.keep_checkpoint_every_n_hours)\n    self.assertIsNone(config.service)\n    self.assertIsNone(config.device_fn)\n    self.assertIsNone(config.experimental_max_worker_delay_secs)\n    self.assertEqual(7200, config.session_creation_timeout_secs)\n    self.assertTrue(config.checkpoint_save_graph_def)\n\n  def test_model_dir(self):\n    empty_config = run_config_lib.RunConfig()\n    self.assertIsNone(empty_config.model_dir)\n\n    new_config = empty_config.replace(model_dir=_TEST_DIR)\n    self.assertEqual(_TEST_DIR, new_config.model_dir)\n\n  def test_replace_with_allowed_properties(self):\n    session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)\n    device_fn = lambda op: '/cpu:0'\n\n    config = run_config_lib.RunConfig().replace(\n        tf_random_seed=11,\n        save_summary_steps=12,\n        save_checkpoints_secs=14,\n        session_config=session_config,\n        keep_checkpoint_max=16,\n        keep_checkpoint_every_n_hours=17,\n        
device_fn=device_fn,\n        session_creation_timeout_secs=18,\n        checkpoint_save_graph_def=False)\n    self.assertEqual(11, config.tf_random_seed)\n    self.assertEqual(12, config.save_summary_steps)\n    self.assertEqual(14, config.save_checkpoints_secs)\n    self.assertEqual(session_config, config.session_config)\n    self.assertEqual(16, config.keep_checkpoint_max)\n    self.assertEqual(17, config.keep_checkpoint_every_n_hours)\n    self.assertEqual(device_fn, config.device_fn)\n    self.assertEqual(18, config.session_creation_timeout_secs)\n    self.assertFalse(config.checkpoint_save_graph_def)\n\n  def test_replace_none_value(self):\n    config = run_config_lib.RunConfig().replace(\n        tf_random_seed=None,\n        model_dir=None,\n        save_summary_steps=None,\n        save_checkpoints_secs=None,\n        save_checkpoints_steps=None,\n        session_config=None,\n        keep_checkpoint_max=None,\n        keep_checkpoint_every_n_hours=None,\n        device_fn=None)\n    self.assertIsNone(config.tf_random_seed)\n    self.assertIsNone(config.model_dir)\n    self.assertIsNone(config.save_summary_steps)\n    self.assertIsNone(config.save_checkpoints_secs)\n    self.assertIsNone(config.save_checkpoints_steps)\n    self.assertIsNone(config.session_config)\n    self.assertIsNone(config.keep_checkpoint_max)\n    self.assertIsNone(config.keep_checkpoint_every_n_hours)\n    self.assertIsNone(config.device_fn)\n\n  def test_replace_with_disallowallowed_properties(self):\n    config = run_config_lib.RunConfig()\n    with self.assertRaises(ValueError):\n      # tf_random_seed is not allowed to be replaced.\n      config.replace(master='_master')\n    with self.assertRaises(ValueError):\n      config.replace(some_undefined_property=123)\n\n  def test_replace(self):\n    config = run_config_lib.RunConfig()\n\n    with self.assertRaisesRegexp(ValueError,\n                                 _NOT_SUPPORTED_REPLACE_PROPERTY_MSG):\n      # master is not allowed to 
be replaced.\n      config.replace(master=_MASTER)\n\n    with self.assertRaisesRegexp(ValueError,\n                                 _NOT_SUPPORTED_REPLACE_PROPERTY_MSG):\n      config.replace(some_undefined_property=_MASTER)\n\n  def test_replace_invalid_values(self):\n    config = run_config_lib.RunConfig()\n\n    with self.assertRaisesRegexp(ValueError, _MODEL_DIR_ERR):\n      config.replace(model_dir='')\n    with self.assertRaisesRegexp(ValueError, _SAVE_SUMMARY_STEPS_ERR):\n      config.replace(save_summary_steps=-1)\n    with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_STEPS_ERR):\n      config.replace(save_checkpoints_steps=-1)\n    with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_SECS_ERR):\n      config.replace(save_checkpoints_secs=-1)\n    with self.assertRaisesRegexp(ValueError, _SESSION_CONFIG_ERR):\n      config.replace(session_config={})\n    with self.assertRaisesRegexp(ValueError, _KEEP_CKPT_MAX_ERR):\n      config.replace(keep_checkpoint_max=-1)\n    with self.assertRaisesRegexp(ValueError, _KEEP_CKPT_HOURS_ERR):\n      config.replace(keep_checkpoint_every_n_hours=0)\n    with self.assertRaisesRegexp(ValueError,\n                                 _SESSION_CREATION_TIMEOUT_SECS_ERR):\n      config.replace(session_creation_timeout_secs=0)\n    with self.assertRaisesRegexp(ValueError, _TF_RANDOM_SEED_ERR):\n      config.replace(tf_random_seed=1.0)\n    with self.assertRaisesRegexp(ValueError, _DEVICE_FN_ERR):\n      config.replace(device_fn=lambda x, y: 0)\n    with self.assertRaisesRegexp(ValueError,\n                                 _EXPERIMENTAL_MAX_WORKER_DELAY_SECS_ERR):\n      config.replace(experimental_max_worker_delay_secs='5')\n\n  def test_init_with_allowed_properties(self):\n    session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)\n    device_fn = lambda op: '/cpu:0'\n\n    config = run_config_lib.RunConfig(\n        tf_random_seed=11,\n        save_summary_steps=12,\n        save_checkpoints_secs=14,\n        
session_config=session_config,\n        keep_checkpoint_max=16,\n        keep_checkpoint_every_n_hours=17,\n        device_fn=device_fn,\n        experimental_max_worker_delay_secs=10)\n    self.assertEqual(11, config.tf_random_seed)\n    self.assertEqual(12, config.save_summary_steps)\n    self.assertEqual(14, config.save_checkpoints_secs)\n    self.assertEqual(session_config, config.session_config)\n    self.assertEqual(16, config.keep_checkpoint_max)\n    self.assertEqual(17, config.keep_checkpoint_every_n_hours)\n    self.assertEqual(device_fn, config.device_fn)\n    self.assertEqual(10, config.experimental_max_worker_delay_secs)\n\n  def test_init_none_value(self):\n    config = run_config_lib.RunConfig(\n        tf_random_seed=None,\n        model_dir=None,\n        save_summary_steps=None,\n        save_checkpoints_secs=None,\n        save_checkpoints_steps=None,\n        session_config=None,\n        keep_checkpoint_max=None,\n        keep_checkpoint_every_n_hours=None,\n        device_fn=None)\n    self.assertIsNone(config.tf_random_seed)\n    self.assertIsNone(config.model_dir)\n    self.assertIsNone(config.save_summary_steps)\n    self.assertIsNone(config.save_checkpoints_secs)\n    self.assertIsNone(config.save_checkpoints_steps)\n    self.assertIsNone(config.session_config)\n    self.assertIsNone(config.keep_checkpoint_max)\n    self.assertIsNone(config.keep_checkpoint_every_n_hours)\n    self.assertIsNone(config.device_fn)\n\n  def test_init_invalid_values(self):\n    with self.assertRaisesRegexp(ValueError, _MODEL_DIR_ERR):\n      run_config_lib.RunConfig(model_dir='')\n    with self.assertRaisesRegexp(ValueError, _SAVE_SUMMARY_STEPS_ERR):\n      run_config_lib.RunConfig(save_summary_steps=-1)\n    with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_STEPS_ERR):\n      run_config_lib.RunConfig(save_checkpoints_steps=-1)\n    with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_SECS_ERR):\n      run_config_lib.RunConfig(save_checkpoints_secs=-1)\n    
with self.assertRaisesRegexp(ValueError, _SESSION_CONFIG_ERR):\n      run_config_lib.RunConfig(session_config={})\n    with self.assertRaisesRegexp(ValueError, _KEEP_CKPT_MAX_ERR):\n      run_config_lib.RunConfig(keep_checkpoint_max=-1)\n    with self.assertRaisesRegexp(ValueError, _KEEP_CKPT_HOURS_ERR):\n      run_config_lib.RunConfig(keep_checkpoint_every_n_hours=0)\n    with self.assertRaisesRegexp(ValueError, _TF_RANDOM_SEED_ERR):\n      run_config_lib.RunConfig(tf_random_seed=1.0)\n    with self.assertRaisesRegexp(ValueError, _DEVICE_FN_ERR):\n      run_config_lib.RunConfig(device_fn=lambda x: '/cpu:0')\n    with self.assertRaisesRegexp(ValueError,\n                                 _EXPERIMENTAL_MAX_WORKER_DELAY_SECS_ERR):\n      run_config_lib.RunConfig(experimental_max_worker_delay_secs='5')\n\n  def test_incompatible_train_strategy(self):\n    with self.assertRaisesRegex(\n        ValueError, 'Please use `tf.compat.v1.distribut'\n        'e.experimental.ParameterServerStrategy`'):\n      run_config_lib.RunConfig(\n          train_distribute=tf.compat.v2.distribute.experimental\n          .ParameterServerStrategy.__new__(\n              tf.compat.v2.distribute.experimental.ParameterServerStrategy))\n\n  def test_incompatible_eval_strategy(self):\n    with self.assertRaisesRegex(\n        ValueError, 'Please use `tf.compat.v1.distribut'\n        'e.experimental.ParameterServerStrategy`'):\n      run_config_lib.RunConfig(\n          eval_distribute=tf.compat.v2.distribute.experimental.ParameterServerStrategy\n          .__new__(tf.compat.v2.distribute.experimental.ParameterServerStrategy))\n\n\nclass RunConfigDistributedSettingTest(tf.test.TestCase):\n\n  def _assert_distributed_properties(\n      self, run_config, expected_cluster_spec, expected_task_type,\n      expected_task_id, expected_master, expected_evaluation_master,\n      expected_is_chief, expected_num_worker_replicas,\n      expected_num_ps_replicas):\n    self.assertEqual(expected_cluster_spec, 
run_config.cluster_spec.as_dict())\n    self.assertEqual(expected_task_type, run_config.task_type)\n    self.assertEqual(expected_task_id, run_config.task_id)\n    self.assertEqual(expected_master, run_config.master)\n    self.assertEqual(expected_evaluation_master, run_config.evaluation_master)\n    self.assertEqual(expected_is_chief, run_config.is_chief)\n    self.assertEqual(expected_num_worker_replicas,\n                     run_config.num_worker_replicas)\n    self.assertEqual(expected_num_ps_replicas, run_config.num_ps_replicas)\n\n  def test_default_values(self):\n    self._assert_distributed_properties(\n        run_config=run_config_lib.RunConfig(),\n        expected_cluster_spec={},\n        expected_task_type=run_config_lib.TaskType.WORKER,\n        expected_task_id=0,\n        expected_master='',\n        expected_evaluation_master='',\n        expected_is_chief=True,\n        expected_num_worker_replicas=1,\n        expected_num_ps_replicas=0)\n\n  def test_tf_config_for_local(self):\n    tf_config = {'task': {'type': run_config_lib.TaskType.WORKER, 'index': 0}}\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self._assert_distributed_properties(\n        run_config=run_config,\n        expected_cluster_spec={},\n        expected_task_type=run_config_lib.TaskType.WORKER,\n        expected_task_id=0,\n        expected_master='',\n        expected_evaluation_master='',\n        expected_is_chief=True,\n        expected_num_worker_replicas=1,\n        expected_num_ps_replicas=0)\n    self.assertEqual(0, run_config.global_id_in_cluster)\n    self.assertIsNone(run_config.session_config, None)\n\n  def test_session_master_for_local(self):\n    tf_config = {'session_master': '_my_master'}\n    self._assert_distributed_properties(\n        run_config=_create_run_config_with_cluster_spec(tf_config),\n        expected_cluster_spec={},\n        expected_task_type=run_config_lib.TaskType.WORKER,\n        expected_task_id=0,\n        
expected_master='_my_master',\n        expected_evaluation_master='',\n        expected_is_chief=True,\n        expected_num_worker_replicas=1,\n        expected_num_ps_replicas=0)\n\n  def test_eval_session_master_for_local(self):\n    tf_config = {'eval_session_master': '_my_eval_master'}\n    self._assert_distributed_properties(\n        run_config=_create_run_config_with_cluster_spec(tf_config),\n        expected_cluster_spec={},\n        expected_task_type=run_config_lib.TaskType.WORKER,\n        expected_task_id=0,\n        expected_master='',\n        expected_evaluation_master='_my_eval_master',\n        expected_is_chief=True,\n        expected_num_worker_replicas=1,\n        expected_num_ps_replicas=0)\n\n  def test_invalid_task_type_for_local(self):\n    tf_config = {'task': {'type': run_config_lib.TaskType.CHIEF, 'index': 0}}\n    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE_FOR_LOCAL_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_invalid_task_index_for_local(self):\n    tf_config = {'task': {'type': run_config_lib.TaskType.WORKER, 'index': 1}}\n    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_INDEX_FOR_LOCAL_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_chief_tf_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.CHIEF,\n            'index': 0\n        }\n    }\n    self._assert_distributed_properties(\n        run_config=_create_run_config_with_cluster_spec(tf_config),\n        expected_cluster_spec=tf_config['cluster'],\n        expected_task_type=run_config_lib.TaskType.CHIEF,\n        expected_task_id=0,\n        expected_master='grpc://host0:0',\n        expected_evaluation_master='',\n       
 expected_is_chief=True,\n        expected_num_worker_replicas=4,\n        expected_num_ps_replicas=2)\n\n  def test_session_master_from_single_node_tf_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n        },\n        'task': {\n            'type': run_config_lib.TaskType.CHIEF,\n            'index': 0\n        },\n        'session_master': '_my_master'\n    }\n    self.assertEqual('_my_master',\n                     _create_run_config_with_cluster_spec(tf_config).master)\n\n  def test_session_master_from_multiple_nodes_tf_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n        },\n        'task': {\n            'type': run_config_lib.TaskType.CHIEF,\n            'index': 0\n        },\n        'session_master': '_my_master'\n    }\n    self.assertEqual('_my_master',\n                     _create_run_config_with_cluster_spec(tf_config).master)\n\n  def test_fail_with_multiple_chief_nodes(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0', 'host:6:6'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n    }\n    with self.assertRaisesRegexp(ValueError, _ONE_CHIEF_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_fail_with_missing_chief_node(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n    }\n    with self.assertRaisesRegexp(ValueError, _MISSING_CHIEF_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_single_chief_node(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n        },\n        'task': {\n            'type': run_config_lib.TaskType.CHIEF,\n            'index': 0\n        }\n  
  }\n    self._assert_distributed_properties(\n        run_config=_create_run_config_with_cluster_spec(tf_config),\n        expected_cluster_spec=tf_config['cluster'],\n        expected_task_type=run_config_lib.TaskType.CHIEF,\n        expected_task_id=0,\n        expected_master='',\n        expected_evaluation_master='',\n        expected_is_chief=True,\n        expected_num_worker_replicas=1,\n        expected_num_ps_replicas=0)\n\n  def test_fail_with_missing_task_type_for_distributed(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host3:3']\n        },\n    }\n    with self.assertRaisesRegexp(ValueError, _MISSING_TASK_TYPE_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_fail_with_missing_task_index_for_distributed(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host3:3']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.CHIEF,\n        }\n    }\n    with self.assertRaisesRegexp(ValueError, _MISSING_TASK_ID_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_fail_with_index_is_too_large(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host3:3']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.CHIEF,\n            'index': 1\n        }\n    }\n    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_INDEX_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_fail_with_invalid_task_index(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host3:3']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.CHIEF,\n            'index': -1\n        }\n    }\n    with self.assertRaisesRegexp(ValueError, _NEGATIVE_TASK_INDEX_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_fail_with_invalid_task_type(self):\n    tf_config = 
{\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host3:3']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.WORKER,\n            'index': 0\n        }\n    }\n    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_worker_tf_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.WORKER,\n            'index': 1\n        }\n    }\n    self._assert_distributed_properties(\n        run_config=_create_run_config_with_cluster_spec(tf_config),\n        expected_cluster_spec=tf_config['cluster'],\n        expected_task_type=run_config_lib.TaskType.WORKER,\n        expected_task_id=1,\n        expected_master='grpc://host4:4',\n        expected_evaluation_master='',\n        expected_is_chief=False,\n        expected_num_worker_replicas=4,\n        expected_num_ps_replicas=2)\n\n  def test_ps_tf_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.PS,\n            'index': 0\n        }\n    }\n    self._assert_distributed_properties(\n        run_config=_create_run_config_with_cluster_spec(tf_config),\n        expected_cluster_spec=tf_config['cluster'],\n        expected_task_type=run_config_lib.TaskType.PS,\n        expected_task_id=0,\n        expected_master='grpc://host1:1',\n        expected_evaluation_master='',\n        expected_is_chief=False,\n        expected_num_worker_replicas=4,\n        
expected_num_ps_replicas=2)\n\n  def test_evaluator_tf_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.EVALUATOR,\n            'index': 12\n        }\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self._assert_distributed_properties(\n        run_config=run_config,\n        expected_cluster_spec={},\n        expected_task_type=run_config_lib.TaskType.EVALUATOR,\n        expected_task_id=12,\n        expected_master='',\n        expected_evaluation_master='',\n        expected_is_chief=False,  # evaluator is never chief.\n        expected_num_worker_replicas=0,  # evaluator is not in training cluster.\n        expected_num_ps_replicas=0)\n    self.assertIsNone(run_config.global_id_in_cluster)\n\n  def test_eval_master_for_evaluator(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.EVALUATOR,\n            'index': 12\n        },\n        'eval_session_master': 'grpc://123',\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self.assertEqual('grpc://123', run_config.evaluation_master)\n\n  def test_fail_with_invalid_task_index_for_evaluator(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host3:3']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.EVALUATOR,\n            'index': -1\n        }\n    }\n    with self.assertRaisesRegexp(ValueError, _NEGATIVE_TASK_INDEX_ERR):\n      
_create_run_config_with_cluster_spec(tf_config)\n\n  def test_global_id_in_cluster_for_chief(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'],\n            run_config_lib.TaskType.PS: ['host6:3', 'host7:4', 'host8:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.CHIEF,\n            'index': 0,\n        },\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self.assertEqual(0, run_config.global_id_in_cluster)\n\n  def test_global_id_in_cluster_for_worker(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'],\n            run_config_lib.TaskType.PS: ['host6:3', 'host7:4', 'host8:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.WORKER,\n            'index': 2,\n        },\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self.assertEqual(3, run_config.global_id_in_cluster)\n\n  def test_global_id_in_cluster_for_ps(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'],\n            run_config_lib.TaskType.PS: ['host6:3', 'host7:4', 'host8:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.PS,\n            'index': 1,\n        },\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self.assertEqual(5, run_config.global_id_in_cluster)\n\n  def test_global_id_in_cluster_for_multiple_worker_types(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n            'worker': ['host3:3', 'host4:4', 'host5:5'],\n            'other_type': ['host3:1', 'host4:2'],\n            
run_config_lib.TaskType.PS: ['host6:3', 'host7:4', 'host8:5']\n        },\n        'task': {\n            'type': 'other_type',\n            'index': 1,\n        },\n    }\n    # Though 'other_type' is defined after 'worker', based on alphabetical\n    # order, the task type order should be 'chief', 'other_type', 'worker',\n    # 'ps', where 'chief' and 'ps' are predefined to be the top and last in the\n    # order list.\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self.assertEqual(2, run_config.global_id_in_cluster)\n\n\nclass RunConfigDistributedSettingWithMasterTest(tf.test.TestCase):\n\n  def _assert_distributed_properties(\n      self, run_config, expected_cluster_spec, expected_task_type,\n      expected_task_id, expected_master, expected_evaluation_master,\n      expected_is_chief, expected_num_worker_replicas,\n      expected_num_ps_replicas):\n    self.assertEqual(expected_cluster_spec, run_config.cluster_spec.as_dict())\n    self.assertEqual(expected_task_type, run_config.task_type)\n    self.assertEqual(expected_task_id, run_config.task_id)\n    self.assertEqual(expected_master, run_config.master)\n    self.assertEqual(expected_evaluation_master, run_config.evaluation_master)\n    self.assertEqual(expected_is_chief, run_config.is_chief)\n    self.assertEqual(expected_num_worker_replicas,\n                     run_config.num_worker_replicas)\n    self.assertEqual(expected_num_ps_replicas, run_config.num_ps_replicas)\n\n  def test_invalid_task_type_for_local(self):\n    tf_config = {'task': {'type': run_config_lib.TaskType.MASTER, 'index': 0}}\n    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE_FOR_LOCAL_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_master_node(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 
'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.MASTER,\n            'index': 0\n        }\n    }\n    self._assert_distributed_properties(\n        run_config=_create_run_config_with_cluster_spec(tf_config),\n        expected_cluster_spec=tf_config['cluster'],\n        expected_task_type=run_config_lib.TaskType.MASTER,\n        expected_task_id=0,\n        expected_master='grpc://host0:0',\n        expected_evaluation_master='',\n        expected_is_chief=True,\n        expected_num_worker_replicas=4,\n        expected_num_ps_replicas=2)\n\n  def test_session_master_in_single_node_tf_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n        },\n        'task': {\n            'type': run_config_lib.TaskType.MASTER,\n            'index': 0\n        },\n        'session_master': '_my_master'\n    }\n    self.assertEqual('_my_master',\n                     _create_run_config_with_cluster_spec(tf_config).master)\n\n  def test_session_master_in_multiple_nodes_tf_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n        },\n        'task': {\n            'type': run_config_lib.TaskType.MASTER,\n            'index': 0\n        },\n        'session_master': '_my_master'\n    }\n    self.assertEqual('_my_master',\n                     _create_run_config_with_cluster_spec(tf_config).master)\n\n  def test_fail_with_multiple_master_nodes(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0', 'host:6:6'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n    }\n    with self.assertRaisesRegexp(ValueError, _ONE_MASTER_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_single_master_node(self):\n    tf_config = {\n        
'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n        },\n        'task': {\n            'type': run_config_lib.TaskType.MASTER,\n            'index': 0\n        }\n    }\n    self._assert_distributed_properties(\n        run_config=_create_run_config_with_cluster_spec(tf_config),\n        expected_cluster_spec=tf_config['cluster'],\n        expected_task_type=run_config_lib.TaskType.MASTER,\n        expected_task_id=0,\n        expected_master='',\n        expected_evaluation_master='',\n        expected_is_chief=True,\n        expected_num_worker_replicas=1,\n        expected_num_ps_replicas=0)\n\n  def test_fail_with_missing_task_type_for_distributed(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host3:3']\n        },\n    }\n    with self.assertRaisesRegexp(ValueError, _MISSING_TASK_TYPE_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_fail_with_missing_task_index_for_distributed(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host3:3']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.MASTER,\n        }\n    }\n    with self.assertRaisesRegexp(ValueError, _MISSING_TASK_ID_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_fail_with_index_is_too_large(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host3:3']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.MASTER,\n            'index': 1\n        }\n    }\n    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_INDEX_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_fail_with_invalid_task_index(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host3:3']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.MASTER,\n            'index': -1\n        }\n  
  }\n    with self.assertRaisesRegexp(ValueError, _NEGATIVE_TASK_INDEX_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_fail_with_invalid_task_type(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host3:3']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.WORKER,\n            'index': 0\n        }\n    }\n    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_worker_tf_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.WORKER,\n            'index': 1\n        }\n    }\n    self._assert_distributed_properties(\n        run_config=_create_run_config_with_cluster_spec(tf_config),\n        expected_cluster_spec=tf_config['cluster'],\n        expected_task_type=run_config_lib.TaskType.WORKER,\n        expected_task_id=1,\n        expected_master='grpc://host4:4',\n        expected_evaluation_master='',\n        expected_is_chief=False,\n        expected_num_worker_replicas=4,\n        expected_num_ps_replicas=2)\n\n  def test_ps_tf_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.PS,\n            'index': 0\n        }\n    }\n    self._assert_distributed_properties(\n        run_config=_create_run_config_with_cluster_spec(tf_config),\n        expected_cluster_spec=tf_config['cluster'],\n        
expected_task_type=run_config_lib.TaskType.PS,\n        expected_task_id=0,\n        expected_master='grpc://host1:1',\n        expected_evaluation_master='',\n        expected_is_chief=False,\n        expected_num_worker_replicas=4,\n        expected_num_ps_replicas=2)\n\n  def test_fail_with_evaluator(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.EVALUATOR,\n            'index': 1\n        }\n    }\n    with self.assertRaisesRegexp(ValueError,\n                                 _INVALID_EVALUATOR_IN_CLUSTER_WITH_MASTER_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_fail_with_chief(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            run_config_lib.TaskType.CHIEF: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.PS,\n            'index': 1\n        }\n    }\n    with self.assertRaisesRegexp(ValueError,\n                                 _INVALID_CHIEF_IN_CLUSTER_WITH_MASTER_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n  def test_global_id_in_cluster_for_master(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'],\n            run_config_lib.TaskType.PS: ['host6:3', 'host7:4', 'host8:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.MASTER,\n            'index': 0,\n        },\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self.assertEqual(0, 
run_config.global_id_in_cluster)\n\n  def test_global_id_in_cluster_for_worker(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'],\n            run_config_lib.TaskType.PS: ['host6:3', 'host7:4', 'host8:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.WORKER,\n            'index': 2,\n        },\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self.assertEqual(3, run_config.global_id_in_cluster)\n\n  def test_global_id_in_cluster_for_ps(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5'],\n            run_config_lib.TaskType.PS: ['host6:3', 'host7:4', 'host8:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.PS,\n            'index': 1,\n        },\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self.assertEqual(5, run_config.global_id_in_cluster)\n\n  def test_global_id_in_cluster_for_multiple_worker_types(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n            'worker': ['host3:3', 'host4:4', 'host5:5'],\n            'other_type': ['host3:1', 'host4:2'],\n            run_config_lib.TaskType.PS: ['host6:3', 'host7:4', 'host8:5']\n        },\n        'task': {\n            'type': 'other_type',\n            'index': 1,\n        },\n    }\n    # Though 'other_type' is defined after 'worker', based on alphabetical\n    # order, the task type order should be 'master', 'other_type', 'worker',\n    # 'ps', where 'master' and 'ps' are predefined to be the top and last in\n    # the order list.\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self.assertEqual(2, run_config.global_id_in_cluster)\n\n\nclass 
RunConfigSaveCheckpointsTest(tf.test.TestCase):\n\n  def test_save_checkpoint(self):\n    empty_config = run_config_lib.RunConfig()\n    self.assertEqual(600, empty_config.save_checkpoints_secs)\n    self.assertIsNone(empty_config.save_checkpoints_steps)\n\n    config_with_steps = empty_config.replace(save_checkpoints_steps=100)\n    del empty_config\n    self.assertEqual(100, config_with_steps.save_checkpoints_steps)\n    self.assertIsNone(config_with_steps.save_checkpoints_secs)\n\n    config_with_secs = config_with_steps.replace(save_checkpoints_secs=200)\n    del config_with_steps\n    self.assertEqual(200, config_with_secs.save_checkpoints_secs)\n    self.assertIsNone(config_with_secs.save_checkpoints_steps)\n\n  def test_save_checkpoint_both_steps_and_secs_are_not_none(self):\n    empty_config = run_config_lib.RunConfig()\n    with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_ERR):\n      empty_config.replace(\n          save_checkpoints_steps=100, save_checkpoints_secs=200)\n\n    with self.assertRaisesRegexp(ValueError, _SAVE_CKPT_ERR):\n      run_config_lib.RunConfig(\n          save_checkpoints_steps=100, save_checkpoints_secs=200)\n\n  def test_save_checkpoint_both_steps_and_secs_are_none(self):\n    config_with_secs = run_config_lib.RunConfig()\n    config_without_ckpt = config_with_secs.replace(\n        save_checkpoints_steps=None, save_checkpoints_secs=None)\n    self.assertIsNone(config_without_ckpt.save_checkpoints_steps)\n    self.assertIsNone(config_without_ckpt.save_checkpoints_secs)\n\n  def test_save_checkpoint_flip_secs_to_none(self):\n    config_with_secs = run_config_lib.RunConfig()\n    config_without_ckpt = config_with_secs.replace(save_checkpoints_secs=None)\n    self.assertIsNone(config_without_ckpt.save_checkpoints_steps)\n    self.assertIsNone(config_without_ckpt.save_checkpoints_secs)\n\n  def test_save_checkpoint_flip_steps_to_none(self):\n    config_with_steps = run_config_lib.RunConfig().replace(\n        
save_checkpoints_steps=100)\n    config_without_ckpt = config_with_steps.replace(save_checkpoints_steps=None)\n    self.assertIsNone(config_without_ckpt.save_checkpoints_steps)\n    self.assertIsNone(config_without_ckpt.save_checkpoints_secs)\n\n\nclass RunConfigServiceKeyTest(tf.test.TestCase):\n\n  def test_arbitrary_key_value_pairs(self):\n    tf_config = {\n        'service': {\n            'key1': [1, 2],\n            'key2': {\n                'a': 3,\n                'b': 4\n            },\n            'key3': 789,\n        },\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self.assertEqual(tf_config['service'], run_config.service)\n\n  def test_missing_service_key(self):\n    tf_config = {\n        'model_dir': '/tmp/123',\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self.assertIsNone(run_config.service)\n\n  def test_fail_with_non_dict(self):\n    tf_config = {\n        'service': 789,\n    }\n    with self.assertRaisesRegexp(TypeError, _INVALID_SERVICE_TYPE_ERR):\n      _create_run_config_with_cluster_spec(tf_config)\n\n\nclass RunConfigModelDirTest(tf.test.TestCase):\n\n  def test_default(self):\n    run_config = run_config_lib.RunConfig()\n    self.assertIsNone(run_config.model_dir)\n\n  def test_model_dir_in_constructor(self):\n    run_config = run_config_lib.RunConfig(model_dir='/tmp/123')\n    self.assertEqual('/tmp/123', run_config.model_dir)\n\n  def test_model_dir_in_tf_config(self):\n    tf_config = {\n        'model_dir': '/tmp/123',\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self.assertEqual('/tmp/123', run_config.model_dir)\n\n  def test_model_dir_both_set_in_both_constructor_and_tf_config(self):\n    model_dir = '/tmp/123'\n    tf_config = {'model_dir': model_dir}\n    kwargs = {'model_dir': model_dir}\n    run_config = _create_run_config_with_cluster_spec(tf_config, **kwargs)\n    self.assertEqual('/tmp/123', run_config.model_dir)\n\n  def 
test_model_dir_different_in_both_constructor_and_tf_config(self):\n    tf_config = {'model_dir': '/tmp/123'}\n    kwargs = {'model_dir': '/tmp/456'}\n    with self.assertRaisesRegexp(ValueError, _MODEL_DIR_MISMATCH_ERR):\n      _create_run_config_with_cluster_spec(tf_config, **kwargs)\n\n  def test_fail_with_empty_string_in_constructor(self):\n    with self.assertRaisesRegexp(ValueError, _MODEL_DIR_ERR):\n      run_config_lib.RunConfig(model_dir='')\n\n  def test_fail_with_empty_string_in_tf_config(self):\n    with self.assertRaisesRegexp(ValueError, _MODEL_DIR_TF_CONFIG_ERR):\n      tf_config = {'model_dir': ''}\n      _create_run_config_with_cluster_spec(tf_config)\n\n\nclass RunConfigSessionConfigTest(tf.test.TestCase):\n\n  def _assert_equal_session_config(self, session_config,\n                                   expected_device_filters):\n\n    rewrite_opts = rewriter_config_pb2.RewriterConfig(\n        meta_optimizer_iterations=rewriter_config_pb2.RewriterConfig.ONE)\n    graph_opts = tf.compat.v1.GraphOptions(rewrite_options=rewrite_opts)\n    expected_session_config = tf.compat.v1.ConfigProto(\n        allow_soft_placement=True,\n        graph_options=graph_opts,\n        device_filters=expected_device_filters)\n    self.assertEqual(session_config, expected_session_config)\n\n  def test_master_session_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.MASTER,\n            'index': 0\n        }\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self._assert_equal_session_config(run_config.session_config,\n                                      ['/job:ps', '/job:master'])\n\n  def test_chief_session_config(self):\n    tf_config = {\n        'cluster': {\n   
         run_config_lib.TaskType.CHIEF: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.CHIEF,\n            'index': 0\n        }\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self._assert_equal_session_config(run_config.session_config,\n                                      ['/job:ps', '/job:chief'])\n\n  def test_worker_session_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.WORKER,\n            'index': 1\n        }\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self._assert_equal_session_config(run_config.session_config,\n                                      ['/job:ps', '/job:worker/task:1'])\n\n  def test_ps_session_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.PS,\n            'index': 1\n        }\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self._assert_equal_session_config(\n        run_config.session_config,\n        ['/job:ps', '/job:worker', '/job:chief', '/job:master'])\n\n  def test_evaluator_session_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            
run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.EVALUATOR,\n            'index': 0\n        }\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self.assertIsNone(run_config.session_config)\n\n  def test_other_type_session_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.MASTER: ['host0:0'],\n            run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n            'other_type': ['host3:1', 'host4:2'],\n            run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4', 'host5:5']\n        },\n        'task': {\n            'type': 'other_type',\n            'index': 0\n        }\n    }\n    run_config = _create_run_config_with_cluster_spec(tf_config)\n    self.assertIsNone(run_config.session_config)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tf_estimator_doctest.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Run doctests for tensorflow.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport re\nimport sys\nimport textwrap\nimport tensorflow as tf\nimport numpy as np\n\nfrom absl import flags\nfrom absl.testing import absltest\n\nimport tensorflow_estimator.python.estimator.estimator_lib as tfe\n\nimport tensorflow.compat.v2 as tf\ntf.estimator = tfe\ntf.compat.v1.enable_v2_behavior()\n\n# We put doctest after absltest so that it picks up the unittest monkeypatch.\n# Otherwise doctest tests aren't runnable at all.\nimport doctest  # pylint: disable=g-import-not-at-top, g-bad-import-order\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_string('module', None, 'A specific module to run doctest on.')\nflags.DEFINE_boolean('list', None,\n                     'List all the modules in the core package imported.')\nflags.DEFINE_string('file', None, 'A specific file to run doctest on.')\n\nflags.mark_flags_as_mutual_exclusive(['module', 'file'])\nflags.mark_flags_as_mutual_exclusive(['list', 'file'])\n\nPACKAGE = 'tensorflow_estimator.python.'\n\n\ndef find_modules():\n  \"\"\"Finds all the modules in the core package imported.\n\n  Returns:\n    A list containing all the modules in 
tensorflow.python.\n  \"\"\"\n\n  tf_modules = []\n  for name, module in sys.modules.items():\n    if name.startswith(PACKAGE):\n      tf_modules.append(module)\n\n  return tf_modules\n\n\ndef filter_on_submodules(all_modules, submodule):\n  \"\"\"Filters all the modules based on the module flag.\n\n  The module flag has to be relative to the core package imported.\n  For example, if `submodule=keras.layers` then, this function will return\n  all the modules in the submodule.\n\n  Args:\n    all_modules: All the modules in the core package.\n    submodule: Submodule to filter from all the modules.\n\n  Returns:\n    All the modules in the submodule.\n  \"\"\"\n\n  filtered_modules = [\n      mod for mod in all_modules if PACKAGE + submodule in mod.__name__\n  ]\n  return filtered_modules\n\n\ndef get_module_and_inject_docstring(file_path):\n  \"\"\"Replaces the docstring of the module with the changed file's content.\n\n  Args:\n    file_path: Path to the file\n\n  Returns:\n    A list containing the module changed by the file.\n  \"\"\"\n\n  file_path = os.path.abspath(file_path)\n  mod_index = file_path.find(PACKAGE.replace('.', os.sep))\n  file_mod_name, _ = os.path.splitext(file_path[mod_index:])\n  file_module = sys.modules[file_mod_name.replace(os.sep, '.')]\n\n  with open(file_path, 'r') as f:\n    content = f.read()\n\n  file_module.__doc__ = content\n\n  return [file_module]\n\n\nclass TfTestCase(tf.test.TestCase):\n\n  def set_up(self, test):\n    self.setUp()\n\n  def tear_down(self, test):\n    self.tearDown()\n\n\nclass CustomOutputChecker(doctest.OutputChecker):\n  \"\"\"Changes the `want` and `got` strings.\n\n  This allows it to be customized before they are compared.\n  \"\"\"\n  ID_RE = re.compile(r'\\bid=(\\d+)\\b')\n  ADDRESS_RE = re.compile(r'\\bat 0x[0-9a-f]*?>')\n\n  def check_output(self, want, got, optionflags):\n    # Replace tf.Tensor's id with ellipsis(...) because tensor's id can change\n    # on each execution. 
Users may forget to use ellipsis while writing\n    # examples in docstrings, so replacing the id with `...` makes it safe.\n    want = self.ID_RE.sub('id=...', want)\n    want = self.ADDRESS_RE.sub('at ...>', want)\n    return doctest.OutputChecker.check_output(self, want, got, optionflags)\n\n  _MESSAGE = textwrap.dedent(\"\"\"\\n\n        #############################################################\n        Check the documentation\n        (go/testable-docstrings) on how to write testable docstrings.\n        #############################################################\"\"\")\n\n  def output_difference(self, example, got, optionflags):\n    got = got + self._MESSAGE\n    return doctest.OutputChecker.output_difference(self, example, got,\n                                                   optionflags)\n\n\ndef load_tests(unused_loader, tests, unused_ignore):\n  \"\"\"Loads all the tests in the docstrings and runs them.\"\"\"\n\n  tf_modules = find_modules()\n\n  if FLAGS.module:\n    tf_modules = filter_on_submodules(tf_modules, FLAGS.module)\n\n  if FLAGS.list:\n    print('**************************************************')\n    for mod in tf_modules:\n      print(mod.__name__)\n    print('**************************************************')\n    return tests\n\n  if FLAGS.file:\n    tf_modules = get_module_and_inject_docstring(FLAGS.file)\n\n  for module in tf_modules:\n    testcase = TfTestCase()\n    tests.addTests(\n        doctest.DocTestSuite(\n            module,\n            test_finder=doctest.DocTestFinder(exclude_empty=False),\n            extraglobs={\n                'tf': tf,\n                'np': np,\n                'os': os\n            },\n            setUp=testcase.set_up,\n            tearDown=testcase.tear_down,\n            checker=CustomOutputChecker(),\n            optionflags=(doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE\n                         | doctest.IGNORE_EXCEPTION_DETAIL\n                         | 
doctest.DONT_ACCEPT_BLANKLINE),\n        ))\n  return tests\n\n\nif __name__ == '__main__':\n  absltest.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tools/__init__.py",
    "content": ""
  },
  {
    "path": "tensorflow_estimator/python/estimator/tools/analytics.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Analytics helpers library.\"\"\"\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n\ndef track_usage(tool_id, tags):\n  \"\"\"No usage tracking for external library.\n\n  Args:\n    tool_id: A string identifier for tool to be tracked.\n    tags: list of string tags that will be added to the tracking.\n  \"\"\"\n  del tool_id, tags  # Unused externally.\n\n\ndef track_numerical_issues(exc_info):\n  \"\"\"No tracking for external library.\n\n  Args:\n    exc_info: Output from `sys.exc_info` (type, value, traceback)\n  \"\"\"\n  del exc_info\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tools/checkpoint_converter.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\nr\"\"\"Checkpoint converter for Canned Estimators in TF 1.x.\n\nThis checkpoint converter tool is mainly for Canned Estimators, including DNN\nLinear and DNNLinearCombined estimators. The allowed optimizers to be converted\ninclude Adam, Adagrad, Ftrl, RMSProp, and SGD.\n\nNote that, this converter is not suitable for the case where 'dnn_optimizer'\nand 'linear_optimizer' in DNNLinearCombined model are the same.\n\nIf your current canned estimators and checkpoints are from TF 1.x, after you\nmigrate the canned estimator to v2 with `tf_keras.optimizers.*`, the converted\ncheckpoint allow you to restore and retrain the model in TF 2.0.\n\nUsage:\n  python checkpoint_convert.py '/path/to/checkpoint' '/path/to/graph.pbtxt' \\\n      '/path/to/new_checkpoint'\n\nFor example, if there is a V1 checkpoint to be converted and the files include:\n  /tmp/my_checkpoint/model.ckpt-100.data-00000-of-00001\n  /tmp/my_checkpoint/model.ckpt-100.index\n  /tmp/my_checkpoint/model.ckpt-100.meta\n  /tmp/my_checkpoint/graph.pbtxt\n\nuse the following command:\n  mkdir /tmp/my_converted_checkpoint &&\n  python checkpoint_convert.py \\\n      /tmp/my_checkpoint/model.ckpt-100 /tmp/my_checkpoint/graph.pbtxt \\\n      /tmp/my_converted_checkpoint/model.ckpt-100\n\nThis will 
generate three converted checkpoint files corresponding to the three\nold checkpoint files in the new directory:\n  /tmp/my_converted_checkpoint/model.ckpt-100.data-00000-of-00001\n  /tmp/my_converted_checkpoint/model.ckpt-100.index\n  /tmp/my_converted_checkpoint/model.ckpt-100.meta\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport argparse\nimport sys\nimport tensorflow as tf\nfrom google.protobuf import text_format\nfrom tensorflow_estimator.python.estimator.util import tf_keras\n\n# Optimizer name mapping from v1 to v2.\nOPT_NAME_V1_TO_V2 = {\n    'Adagrad': 'Adagrad',\n    'RMSProp': 'RMSprop',\n    'Ftrl': 'Ftrl',\n    'Adam': 'Adam',\n    'SGD': 'SGD',\n}\n\n# Hyper-paratmeters of optimizer in checkpoint.\nHP_IN_CKPT = {\n    'Adam': {\n        'beta1_power': 'training/Adam/beta_1',\n        'beta2_power': 'training/Adam/beta_2',\n    },\n}\n\n# Optimzier variable name mapping from v1 to v2.\nOPT_VAR_NAME_V1_TO_V2 = {\n    'Adam': {\n        'Adam': 'm',\n        'Adam_1': 'v',\n    },\n    'Ftrl': {\n        'Ftrl': 'accumulator',\n        'Ftrl_1': 'linear',\n    },\n    'RMSProp': {\n        'RMSProp': 'rms',\n        'RMSProp_1': None,\n    },\n    'Adagrad': {\n        'Adagrad': 'accumulator',\n    },\n}\n\n# Hyper-paratmeters of optimizer in graph.\nHP_IN_GRAPH = {\n    'Adam': ['decay', 'learning_rate'],\n    'Ftrl': [\n        'decay', 'l1_regularization_strength', 'l2_regularization_strength',\n        'beta', 'learning_rate', 'learning_rate_power'\n    ],\n    'RMSProp': ['decay', 'learning_rate', 'momentum', 'rho'],\n    'Adagrad': ['decay', 'learning_rate'],\n    'SGD': ['decay', 'learning_rate', 'momentum'],\n}\n\n# optimizer v2 instance.\nOPT_V2_INSTANCE = {\n    'Adagrad': tf_keras.optimizers.legacy.Adagrad(),\n    'Adam': tf_keras.optimizers.legacy.Adam(),\n    'Ftrl': tf_keras.optimizers.legacy.Ftrl(),\n    'RMSProp': tf_keras.optimizers.legacy.RMSprop(),\n    
'SGD': tf_keras.optimizers.legacy.SGD(),\n}\n\n\ndef _add_new_variable(initial_value, var_name_v2, var_name_v1, var_map,\n                      var_names_map):\n  \"\"\"Creates a new variable and add it to the variable maps.\"\"\"\n  var = tf.Variable(initial_value, name=var_name_v2)\n  var_map[var_name_v2] = var\n  var_names_map[var_name_v2] = var_name_v1\n\n\ndef _add_opt_variable(opt_name_v2, var_name_v1, idx, suffix_v2, reader, var_map,\n                      var_names_map):\n  \"\"\"Adds a new optimizer v2 variable.\"\"\"\n  var_name_v2 = 'training/' + opt_name_v2 + '/' + var_name_v1[:idx] + suffix_v2\n  tensor = reader.get_tensor(var_name_v1)\n  _add_new_variable(tensor, var_name_v2, var_name_v1, var_map, var_names_map)\n\n\ndef _convert_variables_in_ckpt(opt_name_v1, reader, variable_names, var_map,\n                               var_names_map, est_type):\n  \"\"\"Converts all variables in checkpoint from v1 to v2.\"\"\"\n  global_step = None\n  hp_ckpt = None\n  # Global step is needed for Adam for hyper parameter conversion.\n  if opt_name_v1 == 'Adam':\n    global_step = reader.get_tensor('global_step')\n  if opt_name_v1 in HP_IN_CKPT:\n    hp_ckpt = HP_IN_CKPT[opt_name_v1]\n  opt_name_v2 = OPT_NAME_V1_TO_V2[opt_name_v1]\n\n  # For variables with equivalent mapping in checkpoint. There are three types:\n  # 1) Hyper parameters. This is mainly for Adam optimizer.\n  # 2) Optimizer variables.\n  # 3) Model variables.\n  for var_name in variable_names:\n    # If a hyper parameter variable is in the checkpoint.\n    if hp_ckpt and any(hp_name in var_name for hp_name in hp_ckpt):\n      for hp_name in hp_ckpt:\n        if hp_name in var_name:\n          var_name_v2 = hp_ckpt[hp_name]\n          tensor = reader.get_tensor(var_name)\n          # For Adam optimizer, in the old checkpoint, the optimizer variables\n          # are beta1_power and beta2_power. 
The corresponding variables in the\n          # new checkpoint are beta_1 and beta_2, and\n          # beta_1 = pow(beta1_power, 1/global_step)\n          # beta_2 = pow(beta2_power, 1/global_step)\n          tensor = tf.math.pow(tensor, 1.0 / global_step)\n          _add_new_variable(tensor, var_name_v2, var_name, var_map,\n                            var_names_map)\n          break\n    # If it's an optimizer variable.\n    elif opt_name_v1 in var_name:\n      suffix_mapping = OPT_VAR_NAME_V1_TO_V2[opt_name_v1]\n      suffix_v1 = var_name.rsplit('/')[-1]\n      suffix_v2 = suffix_mapping[suffix_v1]\n      if suffix_v2:\n        # For DNN model.\n        if est_type == 'dnn':\n          # The optimizer variable of DNN model in TF 1.x has 't_0' in its\n          # name (b/131719899). This is amended in TF 2.0.\n          idx = var_name.rfind('t_0')\n          _add_opt_variable(opt_name_v2, var_name, idx, suffix_v2, reader,\n                            var_map, var_names_map)\n        # for Linear model.\n        elif est_type == 'linear':\n          # The optimizer variable of Linear model in TF 1.x has 'part_0' in its\n          # name (b/131719899). 
This is amended in TF 2.0.\n          idx = var_name.rfind('part_0')\n          _add_opt_variable(opt_name_v2, var_name, idx, suffix_v2, reader,\n                            var_map, var_names_map)\n        # for DNNLinearCombined model.\n        else:\n          idx = var_name.rfind(suffix_v1)\n          _add_opt_variable(opt_name_v2, var_name, idx, suffix_v2, reader,\n                            var_map, var_names_map)\n    # If it's a model variable which is already backward compatible.\n    else:\n      tensor = reader.get_tensor(var_name)\n      _add_new_variable(tensor, var_name, var_name, var_map, var_names_map)\n\n\ndef _convert_hyper_params_in_graph(graph_from_path, opt_name_v1, var_map,\n                                   var_names_map):\n  \"\"\"Generates hyper parameters for optimizer v2 from graph.pbtxt.\"\"\"\n  with tf.io.gfile.GFile(graph_from_path) as f:\n    graph_def = text_format.Parse(f.read(), tf.compat.v1.GraphDef())\n\n  # In keras optimizer, the hyper parameters are also stored in the checkpoint,\n  # while v1 checkpoint doesn't contain any hyper parameters. 
For the\n  # hyper parameter variables, there are two cases:\n  # 1) The hyper parameter exist in the graph.\n  #    If so, the hyper parameter value needs to be extracted from the graph\n  #    node.\n  # 2) The hyper parameter doesn't exist in the graph.\n  #    The value of the hyper parameter is set as the default value from the\n  #    config.\n  nodes_full = HP_IN_GRAPH[opt_name_v1]\n  nodes_in_graph = []\n  opt_name_v2 = OPT_NAME_V1_TO_V2[opt_name_v1]\n  tf.compat.v1.logging.info('For hyper parameter variables that are in Graph:')\n  for node in graph_def.node:\n    node_name = node.name.rsplit('/')[-1]\n    # For case 1), if the hyper parameter of the keras optimizer can be found\n    # in the graph, the graph node value is extracted as the hyper parameter\n    # variable value, and added to the new variable list.\n    if opt_name_v1 + '/' + node_name in nodes_full:\n      hp_value = node.attr.get('value').tensor.float_val[0]\n      hp_name_v2 = 'training/' + opt_name_v2 + '/' + node_name\n      tf.compat.v1.logging.info(\n          'Hyper parameter {} with value {} found in Graph.'.format(\n              hp_name_v2, hp_value))\n      _add_new_variable(hp_value, hp_name_v2, node_name, var_map, var_names_map)\n      # Adds this node to nodes_in_graph\n      nodes_in_graph.append(node_name)\n\n  # For case 2), if the hyper parameter is not in graph, we need to add it\n  # manually. 
The tensor value is its default value from optimizer v2 config.\n  nodes_not_in_graph = sorted(list(set(nodes_full) - set(nodes_in_graph)))\n  opt_v2_config = OPT_V2_INSTANCE[opt_name_v1].get_config()\n  tf.compat.v1.logging.info(\n      'For hyper parameter variables that are NOT in Graph:')\n  for node_name in nodes_not_in_graph:\n    hp_name_v2 = 'training/' + opt_name_v2 + '/' + node_name\n    tf.compat.v1.logging.info(\n        'Hyper parameter {} with default value {} is added.'.format(\n            hp_name_v2, opt_v2_config[node_name]))\n    _add_new_variable(opt_v2_config[node_name], hp_name_v2, node_name, var_map,\n                      var_names_map)\n\n\ndef convert_checkpoint(estimator_type, source_checkpoint, source_graph,\n                       target_checkpoint):\n  \"\"\"Converts checkpoint from TF 1.x to TF 2.0 for CannedEstimator.\n\n  Args:\n    estimator_type: The type of estimator to be converted. So far, the allowed\n      args include 'dnn', 'linear', and 'combined'.\n    source_checkpoint: Path to the source checkpoint file to be read in.\n    source_graph: Path to the source graph file to be read in.\n    target_checkpoint: Path to the target checkpoint to be written out.\n  \"\"\"\n  with tf.Graph().as_default():\n    # Get v1 optimizer names and it's corresponding variable name\n    reader = tf.compat.v1.train.NewCheckpointReader(source_checkpoint)\n    variable_names = sorted(reader.get_variable_to_shape_map())\n    opt_names_v1 = {}\n    for var_name in variable_names:\n      for opt_name in OPT_NAME_V1_TO_V2:\n        if opt_name in var_name:\n          opt_names_v1[opt_name] = var_name\n\n    # SGD doesn't appear in optimizer variables, so we need to add it manually\n    # if no optimizer is found in checkpoint for DNN or Linear model.\n    if not opt_names_v1:\n      if estimator_type == 'dnn' or estimator_type == 'linear':\n        opt_names_v1['SGD'] = ''\n      # As the case is not handled in the converter if dnn_optimizer and\n  
    # linear_optimizer in DNNLinearCombined model are the same, an error is\n      # is raised if two SGD optimizers are used in DNNLinearCombined model.\n      elif estimator_type == 'combined':\n        raise ValueError('Two `SGD` optimizers are used in DNNLinearCombined '\n                         'model, and this is not handled by the checkpoint '\n                         'converter.')\n\n    # A dict mapping from v2 variable name to the v2 variable.\n    var_map = {}\n    # A dict mapping from v2 variable name to v1 variable name.\n    var_names_map = {}\n\n    # Determine the names of dnn_optimizer and linear_optimizer in\n    # DNNLinearCombined model.\n    if estimator_type == 'combined':\n      linear_opt_v1 = None\n      if len(opt_names_v1) == 1:  # When one of the optimizer is 'SGD'.\n        key = list(opt_names_v1.keys())[0]\n        # Case 1: linear_optimizer is non-SGD, and dnn_optimizer is SGD.\n        if opt_names_v1[key].startswith('linear/linear_model/'):\n          linear_opt_v1 = key\n        # Case 2: linear_optimizer is SGD, and dnn_optimizer is non-SGD.\n        if not linear_opt_v1:\n          linear_opt_v1 = 'SGD'\n        opt_names_v1['SGD'] = ''\n      else:  # two non-SGD optimizers\n        for key in opt_names_v1:\n          if opt_names_v1[key].startswith('linear/linear_model/'):\n            linear_opt_v1 = key\n      # Add the 'iter' hyper parameter to the new checkpoint for\n      # linear_optimizer. 
Note dnn_optimizer uses global_step.\n      tensor = reader.get_tensor('global_step')\n      var_name_v2 = 'training/' + OPT_NAME_V1_TO_V2[linear_opt_v1] + '/iter'\n      var_name_v1 = 'global_step'\n      _add_new_variable(tensor, var_name_v2, var_name_v1, var_map,\n                        var_names_map)\n\n    for opt_name_v1 in opt_names_v1:\n      # Convert all existing variables from checkpoint.\n      _convert_variables_in_ckpt(opt_name_v1, reader, variable_names, var_map,\n                                 var_names_map, estimator_type)\n      # Convert hyper parameters for optimizer v2 from the graph.\n      _convert_hyper_params_in_graph(source_graph, opt_name_v1, var_map,\n                                     var_names_map)\n\n    # Log the variable mapping from opt v1 to v2.\n    tf.compat.v1.logging.info(\n        '<----- Variable names converted (v1 --> v2): ----->')\n    for name_v2 in var_names_map:\n      tf.compat.v1.logging.info('%s --> %s' % (var_names_map[name_v2], name_v2))\n\n    # Save to checkpoint v2.\n    saver = tf.compat.v1.train.Saver(var_list=var_map)\n    with tf.compat.v1.Session() as sess:\n      sess.run(tf.compat.v1.initializers.global_variables())\n      tf.compat.v1.logging.info('Writing checkpoint_to_path %s' %\n                                target_checkpoint)\n      saver.save(sess, target_checkpoint)\n\n\ndef main(_):\n  convert_checkpoint(\n      FLAGS.estimator_type,\n      FLAGS.source_checkpoint,\n      FLAGS.source_graph,\n      FLAGS.target_checkpoint,\n  )\n\n\nif __name__ == '__main__':\n  parser = argparse.ArgumentParser()\n  parser.add_argument(\n      'estimator_type',\n      type=str,\n      choices=['dnn', 'linear', 'combined'],\n      help='The type of estimator to be converted. So far, the checkpoint '\n      'converter only supports Canned Estimator. 
So the allowed types '\n      'include linear, dnn and combined.')\n  parser.add_argument(\n      'source_checkpoint',\n      type=str,\n      help='Path to source checkpoint file to be read in.')\n  parser.add_argument(\n      'source_graph', type=str, help='Path to source graph file to be read in.')\n  parser.add_argument(\n      'target_checkpoint',\n      type=str,\n      help='Path to checkpoint file to be written out.')\n  FLAGS, unparsed = parser.parse_known_args()\n  tf.compat.v1.app.run(main=main, argv=[sys.argv[0]] + unparsed)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tools/checkpoint_converter_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for checkpoint_converter.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport os\nimport shutil\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.feature_column import feature_column\nfrom tensorflow_estimator.python.estimator.canned import dnn\nfrom tensorflow_estimator.python.estimator.canned import dnn_linear_combined\nfrom tensorflow_estimator.python.estimator.canned import head as head_lib\nfrom tensorflow_estimator.python.estimator.canned import linear\nfrom tensorflow_estimator.python.estimator.head import regression_head\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.tools import checkpoint_converter\n\n\nclass DNNCheckpointConverterTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._old_ckpt_dir = os.path.join(self.get_temp_dir(), 'source_ckpt')\n    self._new_ckpt_dir = os.path.join(self.get_temp_dir(), 'target_ckpt')\n\n  def tearDown(self):\n    if os.path.exists(self._old_ckpt_dir):\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._old_ckpt_dir)\n    if os.path.exists(self._new_ckpt_dir):\n      
tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._new_ckpt_dir)\n\n  def _test_ckpt_converter(self, train_input_fn, eval_input_fn,\n                           predict_input_fn, input_dimension, label_dimension,\n                           batch_size, optimizer):\n\n    # Create checkpoint in CannedEstimator v1.\n    feature_columns_v1 = [\n        feature_column._numeric_column('x', shape=(input_dimension,))\n    ]\n\n    est_v1 = dnn.DNNEstimator(\n        head=head_lib._regression_head(label_dimension=label_dimension),\n        hidden_units=(2, 2),\n        feature_columns=feature_columns_v1,\n        model_dir=self._old_ckpt_dir,\n        optimizer=optimizer)\n    # Train\n    num_steps = 10\n    est_v1.train(train_input_fn, steps=num_steps)\n    self.assertIsNotNone(est_v1.latest_checkpoint())\n    self.assertTrue(est_v1.latest_checkpoint().startswith(self._old_ckpt_dir))\n\n    # Convert checkpoint from v1 to v2.\n    source_checkpoint = os.path.join(self._old_ckpt_dir, 'model.ckpt-10')\n    source_graph = os.path.join(self._old_ckpt_dir, 'graph.pbtxt')\n    target_checkpoint = os.path.join(self._new_ckpt_dir, 'model.ckpt-10')\n    checkpoint_converter.convert_checkpoint('dnn', source_checkpoint,\n                                            source_graph, target_checkpoint)\n\n    # Create CannedEstimator V2 and restore from the converted checkpoint.\n    feature_columns_v2 = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    est_v2 = dnn.DNNEstimatorV2(\n        head=regression_head.RegressionHead(label_dimension=label_dimension),\n        hidden_units=(2, 2),\n        feature_columns=feature_columns_v2,\n        model_dir=self._new_ckpt_dir,\n        optimizer=optimizer)\n    # Train\n    extra_steps = 10\n    est_v2.train(train_input_fn, steps=extra_steps)\n    self.assertIsNotNone(est_v2.latest_checkpoint())\n    self.assertTrue(est_v2.latest_checkpoint().startswith(self._new_ckpt_dir))\n    # Make 
sure estimator v2 restores from the converted checkpoint, and\n    # continues training extra steps.\n    self.assertEqual(\n        num_steps + extra_steps,\n        est_v2.get_variable_value(tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n  def _create_input_fn(self, label_dimension, batch_size):\n    \"\"\"Creates input_fn for integration test.\"\"\"\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, batch_size=batch_size, shuffle=False)\n\n    return train_input_fn, eval_input_fn, predict_input_fn\n\n  def _test_ckpt_converter_with_an_optimizer(self, opt):\n    \"\"\"Tests checkpoint converter with an optimizer.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    train_input_fn, eval_input_fn, predict_input_fn = self._create_input_fn(\n        label_dimension, batch_size)\n\n    self._test_ckpt_converter(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size,\n        optimizer=opt)\n\n  def test_ckpt_converter_with_adagrad(self):\n    \"\"\"Tests checkpoint converter with Adagrad.\"\"\"\n    self._test_ckpt_converter_with_an_optimizer('Adagrad')\n\n  def test_ckpt_converter_with_rmsprop(self):\n    \"\"\"Tests checkpoint converter with RMSProp.\"\"\"\n    self._test_ckpt_converter_with_an_optimizer('RMSProp')\n\n  def test_ckpt_converter_with_ftrl(self):\n    \"\"\"Tests checkpoint converter with Ftrl.\"\"\"\n    
self._test_ckpt_converter_with_an_optimizer('Ftrl')\n\n  def test_ckpt_converter_with_adam(self):\n    \"\"\"Tests checkpoint converter with Adam.\"\"\"\n    self._test_ckpt_converter_with_an_optimizer('Adam')\n\n  def test_ckpt_converter_with_sgd(self):\n    \"\"\"Tests checkpoint converter with SGD.\"\"\"\n    self._test_ckpt_converter_with_an_optimizer('SGD')\n\n\nclass LinearCheckpointConverterTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._old_ckpt_dir = os.path.join(self.get_temp_dir(), 'source_ckpt')\n    self._new_ckpt_dir = os.path.join(self.get_temp_dir(), 'target_ckpt')\n\n  def tearDown(self):\n    if os.path.exists(self._old_ckpt_dir):\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._old_ckpt_dir)\n    if os.path.exists(self._new_ckpt_dir):\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._new_ckpt_dir)\n\n  def _test_ckpt_converter(self, train_input_fn, eval_input_fn,\n                           predict_input_fn, input_dimension, label_dimension,\n                           batch_size, optimizer):\n\n    # Create checkpoint in CannedEstimator v1.\n    feature_columns_v1 = [\n        feature_column._numeric_column('x', shape=(input_dimension,))\n    ]\n\n    est_v1 = linear.LinearEstimator(\n        head=head_lib._regression_head(label_dimension=label_dimension),\n        feature_columns=feature_columns_v1,\n        model_dir=self._old_ckpt_dir,\n        optimizer=optimizer)\n    # Train\n    num_steps = 10\n    est_v1.train(train_input_fn, steps=num_steps)\n    self.assertIsNotNone(est_v1.latest_checkpoint())\n    self.assertTrue(est_v1.latest_checkpoint().startswith(self._old_ckpt_dir))\n\n    # Convert checkpoint from v1 to v2.\n    source_checkpoint = os.path.join(self._old_ckpt_dir, 'model.ckpt-10')\n    source_graph = os.path.join(self._old_ckpt_dir, 'graph.pbtxt')\n    target_checkpoint = os.path.join(self._new_ckpt_dir, 'model.ckpt-10')\n    
checkpoint_converter.convert_checkpoint('linear', source_checkpoint,\n                                            source_graph, target_checkpoint)\n\n    # Create CannedEstimator V2 and restore from the converted checkpoint.\n    feature_columns_v2 = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    est_v2 = linear.LinearEstimatorV2(\n        head=regression_head.RegressionHead(label_dimension=label_dimension),\n        feature_columns=feature_columns_v2,\n        model_dir=self._new_ckpt_dir,\n        optimizer=optimizer)\n    # Train\n    extra_steps = 10\n    est_v2.train(train_input_fn, steps=extra_steps)\n    self.assertIsNotNone(est_v2.latest_checkpoint())\n    self.assertTrue(est_v2.latest_checkpoint().startswith(self._new_ckpt_dir))\n    # Make sure estimator v2 restores from the converted checkpoint, and\n    # continues training extra steps.\n    self.assertEqual(\n        num_steps + extra_steps,\n        est_v2.get_variable_value(tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n  def _create_input_fn(self, label_dimension, batch_size):\n    \"\"\"Creates input_fn for integration test.\"\"\"\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, batch_size=batch_size, shuffle=False)\n\n    return train_input_fn, eval_input_fn, predict_input_fn\n\n  def _test_ckpt_converter_with_an_optimizer(self, opt):\n    \"\"\"Tests checkpoint converter with an optimizer.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    train_input_fn, eval_input_fn, predict_input_fn = 
self._create_input_fn(\n        label_dimension, batch_size)\n\n    self._test_ckpt_converter(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size,\n        optimizer=opt)\n\n  def test_ckpt_converter_with_adagrad(self):\n    \"\"\"Tests checkpoint converter with Adagrad.\"\"\"\n    self._test_ckpt_converter_with_an_optimizer('Adagrad')\n\n  def test_ckpt_converter_with_rmsprop(self):\n    \"\"\"Tests checkpoint converter with RMSProp.\"\"\"\n    self._test_ckpt_converter_with_an_optimizer('RMSProp')\n\n  def test_ckpt_converter_with_ftrl(self):\n    \"\"\"Tests checkpoint converter with Ftrl.\"\"\"\n    self._test_ckpt_converter_with_an_optimizer('Ftrl')\n\n  def test_ckpt_converter_with_adam(self):\n    \"\"\"Tests checkpoint converter with Adam.\"\"\"\n    self._test_ckpt_converter_with_an_optimizer('Adam')\n\n  def test_ckpt_converter_with_sgd(self):\n    \"\"\"Tests checkpoint converter with SGD.\"\"\"\n    self._test_ckpt_converter_with_an_optimizer('SGD')\n\n\nclass DNNLinearCombinedCheckpointConverterTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._old_ckpt_dir = os.path.join(self.get_temp_dir(), 'source_ckpt')\n    self._new_ckpt_dir = os.path.join(self.get_temp_dir(), 'target_ckpt')\n\n  def tearDown(self):\n    if os.path.exists(self._old_ckpt_dir):\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._old_ckpt_dir)\n    if os.path.exists(self._new_ckpt_dir):\n      tf.compat.v1.summary.FileWriterCache.clear()\n      shutil.rmtree(self._new_ckpt_dir)\n\n  def _test_ckpt_converter(self, train_input_fn, eval_input_fn,\n                           predict_input_fn, input_dimension, label_dimension,\n                           batch_size, dnn_optimizer, linear_optimizer):\n\n    # Create checkpoint in CannedEstimator v1.\n    
linear_feature_columns_v1 = [\n        feature_column._numeric_column('x', shape=(input_dimension,))\n    ]\n    dnn_feature_columns_v1 = [\n        feature_column._numeric_column('x', shape=(input_dimension,))\n    ]\n    est_v1 = dnn_linear_combined.DNNLinearCombinedEstimator(\n        head=head_lib._regression_head(label_dimension=label_dimension),\n        linear_feature_columns=linear_feature_columns_v1,\n        dnn_feature_columns=dnn_feature_columns_v1,\n        dnn_hidden_units=(2, 2),\n        model_dir=self._old_ckpt_dir,\n        dnn_optimizer=dnn_optimizer,\n        linear_optimizer=linear_optimizer)\n    # Train\n    num_steps = 10\n    est_v1.train(train_input_fn, steps=num_steps)\n    self.assertIsNotNone(est_v1.latest_checkpoint())\n    self.assertTrue(est_v1.latest_checkpoint().startswith(self._old_ckpt_dir))\n\n    # Convert checkpoint from v1 to v2.\n    source_checkpoint = os.path.join(self._old_ckpt_dir, 'model.ckpt-10')\n    source_graph = os.path.join(self._old_ckpt_dir, 'graph.pbtxt')\n    target_checkpoint = os.path.join(self._new_ckpt_dir, 'model.ckpt-10')\n    checkpoint_converter.convert_checkpoint('combined', source_checkpoint,\n                                            source_graph, target_checkpoint)\n\n    # Create CannedEstimator V2 and restore from the converted checkpoint.\n    linear_feature_columns_v2 = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    dnn_feature_columns_v2 = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n    est_v2 = dnn_linear_combined.DNNLinearCombinedEstimatorV2(\n        head=regression_head.RegressionHead(label_dimension=label_dimension),\n        linear_feature_columns=linear_feature_columns_v2,\n        dnn_feature_columns=dnn_feature_columns_v2,\n        dnn_hidden_units=(2, 2),\n        model_dir=self._new_ckpt_dir,\n        dnn_optimizer=dnn_optimizer,\n        linear_optimizer=linear_optimizer)\n    # Train\n    extra_steps 
= 10\n    est_v2.train(train_input_fn, steps=extra_steps)\n    self.assertIsNotNone(est_v2.latest_checkpoint())\n    self.assertTrue(est_v2.latest_checkpoint().startswith(self._new_ckpt_dir))\n    # Make sure estimator v2 restores from the converted checkpoint, and\n    # continues training extra steps.\n    self.assertEqual(\n        num_steps + extra_steps,\n        est_v2.get_variable_value(tf.compat.v1.GraphKeys.GLOBAL_STEP))\n\n  def _create_input_fn(self, label_dimension, batch_size):\n    \"\"\"Creates input_fn for integration test.\"\"\"\n    data = np.linspace(0., 2., batch_size * label_dimension, dtype=np.float32)\n    data = data.reshape(batch_size, label_dimension)\n    # learn y = x\n    train_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data},\n        y=data,\n        batch_size=batch_size,\n        num_epochs=None,\n        shuffle=True)\n    eval_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, y=data, batch_size=batch_size, shuffle=False)\n    predict_input_fn = numpy_io.numpy_input_fn(\n        x={'x': data}, batch_size=batch_size, shuffle=False)\n\n    return train_input_fn, eval_input_fn, predict_input_fn\n\n  def _test_ckpt_converter_with_an_optimizer(self, dnn_opt, linear_opt):\n    \"\"\"Tests checkpoint converter with an optimizer.\"\"\"\n    label_dimension = 2\n    batch_size = 10\n    train_input_fn, eval_input_fn, predict_input_fn = self._create_input_fn(\n        label_dimension, batch_size)\n\n    self._test_ckpt_converter(\n        train_input_fn=train_input_fn,\n        eval_input_fn=eval_input_fn,\n        predict_input_fn=predict_input_fn,\n        input_dimension=label_dimension,\n        label_dimension=label_dimension,\n        batch_size=batch_size,\n        dnn_optimizer=dnn_opt,\n        linear_optimizer=linear_opt)\n\n  def test_ckpt_converter_with_adagrad(self):\n    \"\"\"Tests checkpoint converter with Adagrad.\"\"\"\n    self._test_ckpt_converter_with_an_optimizer('Adagrad', 'RMSProp')\n\n  def 
test_ckpt_converter_with_rmsprop(self):\n    \"\"\"Tests checkpoint converter with RMSProp.\"\"\"\n    self._test_ckpt_converter_with_an_optimizer('RMSProp', 'Ftrl')\n\n  def test_ckpt_converter_with_ftrl(self):\n    \"\"\"Tests checkpoint converter with Ftrl.\"\"\"\n    self._test_ckpt_converter_with_an_optimizer('Ftrl', 'Adam')\n\n  def test_ckpt_converter_with_adam(self):\n    \"\"\"Tests checkpoint converter with Adam.\"\"\"\n    self._test_ckpt_converter_with_an_optimizer('Adam', 'SGD')\n\n  def test_ckpt_converter_with_sgd(self):\n    \"\"\"Tests checkpoint converter with SGD.\"\"\"\n    self._test_ckpt_converter_with_an_optimizer('SGD', 'Adagrad')\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/BUILD",
    "content": "# Description: TPUEstimator\n\n# Placeholder: load py_library\n\n# INTERNAL TEST RULE PLACEHOLDER\nload(\"//tensorflow_estimator:estimator.bzl\", \"py_test\", \"tpu_py_test\")\n\nlicenses([\"notice\"])\n\npackage(\n    default_visibility = [\n        \"//tensorflow_estimator:internal\",\n        \"//third_party/tensorflow:__subpackages__\",\n    ],\n)\n\npy_library(\n    name = \"tpu_estimator\",\n    srcs = [\n        \"_tpu_estimator_embedding.py\",\n        \"error_handling.py\",\n        \"iteration_count_estimator.py\",\n        \"tpu_config.py\",\n        \"tpu_context.py\",\n        \"tpu_estimator.py\",\n        \"util.py\",\n    ],\n    srcs_version = \"PY3\",\n    deps = [\n        \"//tensorflow_estimator/python/estimator\",\n        \"//tensorflow_estimator/python/estimator:analytics_tools\",\n        \"//tensorflow_estimator/python/estimator:estimator_export\",\n        \"//tensorflow_estimator/python/estimator:expect_six_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//tensorflow_estimator/python/estimator:export_output\",\n        \"//tensorflow_estimator/python/estimator:model_fn\",\n        \"//tensorflow_estimator/python/estimator:run_config\",\n    ],\n)\n\npy_test(\n    name = \"tpu_config_test\",\n    size = \"small\",\n    srcs = [\"tpu_config_test.py\"],\n    python_version = \"PY3\",\n    deps = [\n        \":tpu_estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\npy_test(\n    name = \"error_handling_test\",\n    size = \"small\",\n    srcs = [\"error_handling_test.py\"],\n    python_version = \"PY3\",\n    deps = [\n        \":tpu_estimator\",\n    ],\n)\n\npy_test(\n    name = \"tpu_estimator_signals_test\",\n    size = \"small\",\n    srcs = [\"tpu_estimator_signals_test.py\"],\n    python_version = \"PY3\",\n    # TODO(jhseu): Remove. 
Fails in OSS on Python 3.\n    tags = [\n        \"no_oss\",\n    ],\n    deps = [\n        \":tpu_estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\ntpu_py_test(\n    name = \"tpu_estimator_test\",\n    size = \"medium\",\n    timeout = \"long\",\n    srcs = [\"tpu_estimator_test.py\"],\n    args = [\n        \"--test_num_shards=2\",\n    ],\n    disable_experimental = True,\n    shard_count = 2,\n    srcs_version = \"PY3\",\n    deps = [\n        \":tpu_estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//third_party/py/absl/flags\",\n    ],\n)\n\ntpu_py_test(\n    name = \"tpu_estimator_embedding_test\",\n    size = \"medium\",\n    timeout = \"long\",\n    srcs = [\n        \"tpu_estimator_embedding_test.py\",\n    ],\n    args = [\n        \"--test_num_shards=2\",\n    ],\n    # TODO(b/140117863): Hanging, then timeout\n    disable_experimental = True,\n    shard_count = 4,\n    srcs_version = \"PY3\",\n    deps = [\n        \":tpu_estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//third_party/py/absl/flags\",\n    ],\n)\n\ntpu_py_test(\n    name = \"tpu_estimator_evaluation_test\",\n    size = \"medium\",\n    timeout = \"long\",\n    srcs = [\"tpu_estimator_evaluation_test.py\"],\n    args = [\n        \"--test_num_shards=2\",\n    ],\n    disable_experimental = True,\n    shard_count = 2,\n    srcs_version = \"PY3\",\n    deps = [\n        \":tpu_estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//third_party/py/absl/flags\",\n    ],\n)\n\ntpu_py_test(\n    name = \"tpu_estimator_export_test\",\n    size = \"medium\",\n    srcs = [\"tpu_estimator_export_test.py\"],\n    args = [\n        \"--test_num_shards=2\",\n    ],\n    disable_experimental = True,\n    
shard_count = 2,\n    srcs_version = \"PY3\",\n    deps = [\n        \":tpu_estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\ntpu_py_test(\n    name = \"tpu_estimator_gradients_test\",\n    size = \"medium\",\n    srcs = [\n        \"tpu_estimator_gradients_test.py\",\n    ],\n    args = [\n        \"--test_num_shards=2\",\n        \"--xla_jf_conv_full_precision=true\",\n    ],\n    # TODO(b/140117863): Fatal error from hardware\n    disable_experimental = True,\n    disable_mlir_bridge = False,\n    shard_count = 2,\n    srcs_version = \"PY3\",\n    deps = [\n        \":tpu_estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\ntpu_py_test(\n    name = \"tpu_estimator_input_v2_test\",\n    size = \"medium\",\n    srcs = [\"tpu_estimator_input_v2_test.py\"],\n    disable_experimental = True,\n    srcs_version = \"PY3\",\n    deps = [\n        \":tpu_estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\ntpu_py_test(\n    name = \"tpu_estimator_integration_test\",\n    size = \"medium\",\n    srcs = [\"tpu_estimator_integration_test.py\"],\n    args = [\n        \"--test_num_shards=2\",\n    ],\n    disable_experimental = True,\n    srcs_version = \"PY3\",\n    deps = [\n        \":tpu_estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\ntpu_py_test(\n    name = \"tpu_estimator_model_parallelism_test\",\n    size = \"medium\",\n    srcs = [\"tpu_estimator_model_parallelism_test.py\"],\n    args = [\n    ],\n    disable_experimental = True,\n    srcs_version = \"PY3\",\n    deps = [\n        \":tpu_estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    
],\n)\n\npy_test(\n    name = \"autotuning_iterations_per_loop_test\",\n    size = \"small\",\n    srcs = [\"autotuning_iterations_per_loop_test.py\"],\n    python_version = \"PY3\",\n    deps = [\n        \":tpu_estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n    ],\n)\n\ntpu_py_test(\n    name = \"tpu_enqueue_sequence_test\",\n    size = \"medium\",\n    srcs = [\"tpu_enqueue_sequence_test.py\"],\n    disable_experimental = True,\n    python_version = \"PY3\",\n    srcs_version = \"PY3\",\n    deps = [\n        \":tpu_estimator\",\n        \"//tensorflow_estimator/python/estimator:expect_absl_installed\",\n        \"//tensorflow_estimator/python/estimator:expect_tensorflow_installed\",\n        \"//third_party/tensorflow/contrib/summary\",\n    ],\n)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/__init__.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# =============================================================================\n\"\"\"TPUEstimator.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/_tpu_estimator_embedding.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ===================================================================\n\"\"\"Tooling for support TPU embedding in TPUEstimator.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport tensorflow as tf\n\nfrom tensorflow.python.feature_column import feature_column as core_fc\nfrom tensorflow.python.feature_column import feature_column_lib as core_fc_lib\nfrom tensorflow.python.feature_column import utils as fc_utils\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.tpu import feature_column as tpu_fc\nfrom tensorflow.python.tpu import feature_column_v2 as tpu_fc_v2\nfrom tensorflow.python.tpu import tpu_embedding\nfrom tensorflow.python.tpu.tpu_embedding import AdagradParameters\nfrom tensorflow.python.tpu.tpu_embedding import AdamParameters\nfrom tensorflow.python.tpu.tpu_embedding import FtrlParameters\nfrom tensorflow.python.tpu.tpu_embedding import MomentumParameters\nfrom tensorflow.python.tpu.tpu_embedding import ProximalAdagradParameters\nfrom tensorflow.python.tpu.tpu_embedding import RMSPropParameters\nfrom tensorflow.python.tpu.tpu_embedding import StochasticGradientDescentParameters\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom 
tensorflow_estimator.python.estimator.estimator_export import estimator_export\n\n# pylint: disable=protected-access\n_TPU_EMBEDDING_COLUMN_CLASSES = (tpu_fc._TPUEmbeddingColumn,\n                                 tpu_fc._TPUSharedEmbeddingColumn,\n                                 tpu_fc_v2._TPUEmbeddingColumnV2,\n                                 tpu_fc_v2._TPUSharedEmbeddingColumnV2)\n_TPU_DEVICE_SPECIFIC_EMBEDDING_COLUMNS = (\n    tpu_fc_v2._TPUDeviceSpecificEmbeddingColumnV2,\n    tpu_fc_v2._TPUSharedDeviceSpecificEmbeddingColumnV2)\n_EMBEDDING_COLUMN_CLASSES = (core_fc._EmbeddingColumn,\n                             core_fc_lib.EmbeddingColumn,\n                             core_fc._SharedEmbeddingColumn)\n_SUPPORTED_FEATURE_COLUMNS = (core_fc._NumericColumn, core_fc_lib.NumericColumn)\n\n_SUPPORTED_OPTIMIZERS = (\n    ProximalAdagradParameters,\n    AdagradParameters,\n    AdamParameters,\n    FtrlParameters,\n    StochasticGradientDescentParameters,\n    MomentumParameters,\n    RMSPropParameters,\n)\n\n# pylint: enable=protected-access\n\n_TABLE_NAME_PREFIX = 'tbl_'\n_LEN_TABLE_NAME_PREFIX = len(_TABLE_NAME_PREFIX)\n\n\ndef _get_table_name_from_embedding_var_name(embedding_var_name):\n  return '{}{}'.format(_TABLE_NAME_PREFIX, embedding_var_name)\n\n\ndef _get_embedding_var_name_from_table_name(table_name):\n  return table_name[_LEN_TABLE_NAME_PREFIX:]\n\n\ndef _get_embedding_variable_name(scope_name, var_name):\n  if scope_name:\n    scope_name = scope_name + '/'\n  return '{}{}'.format(scope_name, var_name)\n\n\ndef _get_slot_variable_names(scope_name, var_name, optimization_parameters):\n  \"\"\"Return embedding variable names which are consistent with CPU runs.\"\"\"\n  if scope_name:\n    scope_name = scope_name + '/'\n  if isinstance(optimization_parameters,\n                tf.compat.v1.tpu.experimental.AdagradParameters):\n    return tpu_embedding.AdagradSlotVariableNames('{}{}/Adagrad'.format(\n        scope_name, var_name))\n  elif 
isinstance(optimization_parameters,\n                  tf.compat.v1.tpu.experimental.AdamParameters):\n    return tpu_embedding.AdamSlotVariableNames(\n        '{}{}/Adam/m'.format(scope_name, var_name),\n        '{}{}/Adam/v'.format(scope_name, var_name))\n  elif isinstance(optimization_parameters,\n                  tf.compat.v1.tpu.experimental.FtrlParameters):\n    return tpu_embedding.FtrlSlotVariableNames(\n        '{}{}/Ftrl'.format(scope_name, var_name),  # accumulator\n        '{}{}/Ftrl_1'.format(scope_name, var_name))  # linear\n  elif isinstance(optimization_parameters, MomentumParameters):\n    return tpu_embedding.MomentumSlotVariableNames('{}{}/Momentum'.format(\n        scope_name, var_name))\n  elif isinstance(optimization_parameters, RMSPropParameters):\n    return tpu_embedding.RMSPropSlotVariableNames(\n        ms='{}{}/RMSProp/ms'.format(scope_name, var_name),\n        mom='{}{}/RMSProp/mom'.format(scope_name, var_name),\n    )\n  elif isinstance(optimization_parameters, ProximalAdagradParameters):\n    return tpu_embedding.ProximalAdagradSlotVariableNames(\n        '{}{}/ProximalAdagrad'.format(scope_name, var_name))\n  elif isinstance(\n      optimization_parameters,\n      tf.compat.v1.tpu.experimental.StochasticGradientDescentParameters):\n    return None\n  else:\n    raise ValueError('Support to infer full variable name '\n                     'for optimization_parameter {} has not been added.'.format(\n                         optimization_parameters))\n\n\ndef get_full_variable_names(graph,\n                            table_to_config_dict,\n                            optimization_parameters=None):\n  \"\"\"Return embedding variable names and slot variables which are consistent with CPU runs.\"\"\"\n  collection = graph.get_collection_ref(tpu_fc._TPU_FC_TO_SCOPE)  # pylint: disable=protected-access\n  if not collection:\n    raise RuntimeError(\n        'Embedding feature column did not capture any thing. 
Make sure the '\n        'feature columns passed to TPUEstimator constructor is properly '\n        'used in model_fn.')\n\n  embedding_variable_name_by_table = {}\n  slot_variable_names_by_table = {}\n  for table_name in table_to_config_dict:\n    embedding_var_name = _get_embedding_var_name_from_table_name(table_name)\n    (scope_name, var_name) = collection[0][embedding_var_name]\n    embedding_variable_name_by_table[table_name] = (\n        _get_embedding_variable_name(scope_name, var_name))\n    if optimization_parameters:\n      slot_variable_names_by_table[table_name] = _get_slot_variable_names(\n          scope_name, var_name, optimization_parameters)\n\n  graph.clear_collection(tpu_fc._TPU_FC_TO_SCOPE)  # pylint: disable=protected-access\n  return embedding_variable_name_by_table, slot_variable_names_by_table\n\n\ndef get_configs_from_feature_columns(feature_columns):\n  \"\"\"Create configs for TPUEmbedding etc from a list of feature columns.\n\n  Args:\n    feature_columns: a list of supported feature columns.\n\n  Returns:\n    A tuple of dicts, the first maps tables to their config, the second maps\n    features to their config, the third maps learning rate key to callback that\n    takes global step and outputs dynamic learning rate.\n  \"\"\"\n\n  allowed = (\n      tpu_fc_v2._TPUEmbeddingColumnV2,  # pylint: disable=protected-access\n      tpu_fc_v2._TPUSharedEmbeddingColumnV2)  # pylint: disable=protected-access\n  warn = (tpu_fc._TPUEmbeddingColumn, tpu_fc._TPUSharedEmbeddingColumn)  # pylint: disable=protected-access\n\n  for column in feature_columns:\n    if not isinstance(column, allowed + warn):\n      raise TypeError(\n          'Unsupported feature column {}. Supported types are {}.'.format(\n              type(column), allowed))\n    if isinstance(column, warn):\n      tf.compat.v1.logging.warn(\n          'Columns of type {} are deprecated. 
Supported types are {}.'.format(\n              type(column), allowed))\n\n  table_to_config = {}\n  feature_to_config = {}\n  for column in feature_columns:\n    feature_name = column.get_feature_key_name()\n    table_name = _get_table_name_from_embedding_var_name(\n        column.get_embedding_var_name())\n    if feature_name in feature_to_config:\n      raise ValueError(\n          'Feature column {} is used with multiple embeddings and this is '\n          'not supported.'.format(feature_name))\n    feature_to_config[feature_name] = tpu_embedding.FeatureConfig(\n        table_id=table_name,\n        max_sequence_length=column.get_max_sequence_length(),\n        weight_key=column.get_weight_key_name())\n    vocabulary_size, dimension = column.get_embedding_table_size()\n    table_to_config[table_name] = tpu_embedding.TableConfig(\n        vocabulary_size=vocabulary_size,\n        dimension=dimension,\n        initializer=column.get_initializer(),\n        combiner=column.get_combiner(),\n        learning_rate_fn=column.get_learning_rate_fn())\n\n  return table_to_config, feature_to_config\n\n\n@estimator_export(v1=['estimator.tpu.experimental.EmbeddingConfigSpec'])\nclass EmbeddingConfigSpec(\n    collections.namedtuple('EmbeddingConfigSpec', [\n        'feature_columns', 'tensor_core_feature_columns',\n        'optimization_parameters', 'clipping_limit',\n        'pipeline_execution_with_tensor_core',\n        'experimental_gradient_multiplier_fn', 'feature_to_config_dict',\n        'table_to_config_dict', 'partition_strategy', 'profile_data_directory'\n    ])):\n  \"\"\"Class to keep track of the specification for TPU embeddings.\n\n  Pass this class to `tf.estimator.tpu.TPUEstimator` via the\n  `embedding_config_spec` parameter. At minimum you need to specify\n  `feature_columns` and `optimization_parameters`. 
The feature columns passed\n  should be created with some combination of\n  `tf.tpu.experimental.embedding_column` and\n  `tf.tpu.experimental.shared_embedding_columns`.\n\n  TPU embeddings do not support arbitrary Tensorflow optimizers and the\n  main optimizer you use for your model will be ignored for the embedding table\n  variables. Instead TPU embeddigns support a fixed set of predefined optimizers\n  that you can select from and set the parameters of. These include adagrad,\n  adam and stochastic gradient descent. Each supported optimizer has a\n  `Parameters` class in the `tf.tpu.experimental` namespace.\n\n  ```\n  column_a = tf.feature_column.categorical_column_with_identity(...)\n  column_b = tf.feature_column.categorical_column_with_identity(...)\n  column_c = tf.feature_column.categorical_column_with_identity(...)\n  tpu_shared_columns = tf.tpu.experimental.shared_embedding_columns(\n      [column_a, column_b], 10)\n  tpu_non_shared_column = tf.tpu.experimental.embedding_column(\n      column_c, 10)\n  tpu_columns = [tpu_non_shared_column] + tpu_shared_columns\n  ...\n  def model_fn(features):\n    dense_features = tf_keras.layers.DenseFeature(tpu_columns)\n    embedded_feature = dense_features(features)\n    ...\n\n  estimator = tf.estimator.tpu.TPUEstimator(\n      model_fn=model_fn,\n      ...\n      embedding_config_spec=tf.estimator.tpu.experimental.EmbeddingConfigSpec(\n          column=tpu_columns,\n          optimization_parameters=(\n              tf.estimator.tpu.experimental.AdagradParameters(0.1))))\n  ```\n\n  @compatibility(TF2)\n  TPU Estimator manages its own TensorFlow graph and session, so it is not\n  compatible with TF2 behaviors. We recommend that you migrate to the newer\n  `tf.distribute.TPUStrategy`. 
See the\n  [TPU guide](https://www.tensorflow.org/guide/tpu) for details.\n  @end_compatibility\n  \"\"\"\n\n  def __new__(cls,\n              feature_columns=None,\n              optimization_parameters=None,\n              clipping_limit=None,\n              pipeline_execution_with_tensor_core=False,\n              experimental_gradient_multiplier_fn=None,\n              feature_to_config_dict=None,\n              table_to_config_dict=None,\n              partition_strategy='div',\n              profile_data_directory=None):\n    \"\"\"Creates an `EmbeddingConfigSpec` instance.\n\n    Args:\n      feature_columns: All embedding `FeatureColumn`s used by model.\n      optimization_parameters: An instance of `AdagradParameters`,\n        `AdamParameters` or `StochasticGradientDescentParameters`. This\n        optimizer will be applied to all embedding variables specified by\n        `feature_columns`.\n      clipping_limit: (Optional) Clipping limit (absolute value).\n      pipeline_execution_with_tensor_core: setting this to `True` makes training\n        faster, but trained model will be different if step N and step N+1\n        involve the same set of embedding IDs. Please see\n        `tpu_embedding_configuration.proto` for details.\n      experimental_gradient_multiplier_fn: (Optional) A Fn taking global step as\n        input returning the current multiplier for all embedding gradients.\n      feature_to_config_dict: A dictionary mapping feature names to instances of\n        the class `FeatureConfig`. Either features_columns or the pair of\n        `feature_to_config_dict` and `table_to_config_dict` must be specified.\n      table_to_config_dict: A dictionary mapping feature names to instances of\n        the class `TableConfig`. Either features_columns or the pair of\n        `feature_to_config_dict` and `table_to_config_dict` must be specified.\n      partition_strategy: A string, determining how tensors are sharded to the\n        tpu hosts. 
See `tf.nn.safe_embedding_lookup_sparse` for more details.\n        Allowed value are `\"div\"` and `\"mod\"'. If `\"mod\"` is used, evaluation\n        and exporting the model to CPU will not work as expected.\n      profile_data_directory: Directory where embedding lookup statistics are\n        stored. These statistics summarize information about the inputs to the\n        embedding lookup operation, in particular, the average number of\n        embedding IDs per example and how well the embedding IDs are load\n        balanced across the system. The lookup statistics are used during TPU\n        initialization for embedding table partitioning. Collection of lookup\n        statistics is done at runtime by  profiling the embedding inputs, only a\n        small fraction of input samples are profiled to minimize host CPU\n        overhead. Once a suitable number of samples are profiled, the lookup\n        statistics are saved to table-specific files in the profile data\n        directory generally at the end of a TPU training loop. The filename\n        corresponding to each table is obtained by hashing table specific\n        parameters (e.g., table name and number of features) and global\n        configuration parameters (e.g., sharding strategy and task count). 
The\n        same profile data directory can be shared among several models to reuse\n        embedding lookup statistics.\n\n    Returns:\n      An `EmbeddingConfigSpec` instance.\n\n    Raises:\n      ValueError: If the feature_columns are not specified.\n      TypeError: If the feature columns are not of ths correct type (one of\n        _SUPPORTED_FEATURE_COLUMNS, _TPU_EMBEDDING_COLUMN_CLASSES OR\n        _EMBEDDING_COLUMN_CLASSES).\n      ValueError: If `optimization_parameters` is not one of the required types.\n    \"\"\"\n    if (not feature_columns and\n        not (feature_to_config_dict and table_to_config_dict) or\n        (feature_columns and\n         (feature_to_config_dict and table_to_config_dict))):\n      raise ValueError('Exactly one of `feature_columns` and the pair '\n                       '`feature_to_config_dict` and `table_to_config_dict` '\n                       'must be be specified.')\n\n    if partition_strategy not in ('div', 'mod'):\n      raise ValueError('Invalid partition_strategy {}. Must be one of \"mod\" or '\n                       '\"div\".'.format(partition_strategy))\n\n    tensor_core_feature_columns = None\n    embedding_core_feature_columns = None\n    if feature_columns:\n      tensor_core_feature_columns = []\n      embedding_core_feature_columns = []\n      # It is unknown at this moment, whether the TPUEstimator is running in CPU\n      # or TPU mode. 
So allow non-TPU embedding columns also.\n      supported_classes = tuple(\n          list(_SUPPORTED_FEATURE_COLUMNS) +\n          list(_TPU_EMBEDDING_COLUMN_CLASSES) + list(_EMBEDDING_COLUMN_CLASSES))\n\n      for column in feature_columns:\n        if (isinstance(column, _TPU_DEVICE_SPECIFIC_EMBEDDING_COLUMNS) and\n            (column._embedding_lookup_device ==  # pylint: disable=protected-access\n             tpu_fc_v2.EmbeddingDevice.TPU_TENSOR_CORE)):\n          tensor_core_feature_columns.append(column)\n        else:\n          embedding_core_feature_columns.append(column)\n        if not isinstance(column, supported_classes):\n          raise TypeError(\n              'All feature columns must be supported types in {}. Got {}'\n              .format(supported_classes, type(column)))\n\n      if not isinstance(optimization_parameters, _SUPPORTED_OPTIMIZERS):\n        raise ValueError('optimization_parameters must be an instance of type '\n                         '{}. Got {}.'.format(_SUPPORTED_OPTIMIZERS,\n                                              type(optimization_parameters)))\n    else:\n      for feature, config in feature_to_config_dict.items():\n        if not isinstance(config, tpu_embedding.FeatureConfig):\n          raise TypeError(\n              'Config for feature {} must be of type `FeatureConfig`. Got {}'\n              .format(feature, type(config)))\n        if config.table_id not in table_to_config_dict:\n          raise ValueError('Feature {} refers to table {} which is not in the '\n                           'table_to_config_dict.'.format(\n                               feature, config.table_id))\n      for table, config in table_to_config_dict.items():\n        if not isinstance(config, tpu_embedding.TableConfig):\n          raise TypeError(\n              'Config for table {} must be of type `TableConfig`. 
Got '\n              '{}'.format(table, type(config)))\n\n    return super(EmbeddingConfigSpec, cls).__new__(\n        cls,\n        feature_columns=embedding_core_feature_columns,\n        tensor_core_feature_columns=tensor_core_feature_columns,\n        optimization_parameters=optimization_parameters,\n        clipping_limit=clipping_limit,\n        pipeline_execution_with_tensor_core=pipeline_execution_with_tensor_core,\n        experimental_gradient_multiplier_fn=experimental_gradient_multiplier_fn,\n        feature_to_config_dict=feature_to_config_dict,\n        table_to_config_dict=table_to_config_dict,\n        partition_strategy=partition_strategy,\n        profile_data_directory=profile_data_directory)\n\n\nclass EmbeddingConfig(object):\n  \"\"\"This is the internal immutable object for embedding config.\n\n  `_EmbeddingConfig` is responsible to _translate_ user provided\n  `EmbeddingConfigSpec` to internal data structures, mostly constructor\n  arguments of `TPUEmbedding`.\n  \"\"\"\n\n  def __init__(self, embedding_config_spec, train_batch_size, eval_batch_size,\n               num_hosts, num_cores, run_config):\n    if not embedding_config_spec:\n      raise ValueError('embedding_config_spec cannot be None.')\n\n    self._embedding_config_spec = embedding_config_spec\n    self._train_batch_size = train_batch_size\n    self._eval_batch_size = eval_batch_size\n    self._num_hosts = num_hosts\n    self._num_cores = num_cores\n    self._run_config = run_config\n\n    if embedding_config_spec.feature_columns:\n      self._table_to_config_dict, self._feature_to_config_dict = (\n          get_configs_from_feature_columns(\n              embedding_config_spec.feature_columns))\n    else:\n      self._table_to_config_dict = embedding_config_spec.table_to_config_dict\n      self._feature_to_config_dict = embedding_config_spec.feature_to_config_dict\n    self._partition_strategy = embedding_config_spec.partition_strategy\n    self._mode_to_tpu_embedding_dict = 
{}\n    self.dummy_table_variables = None\n\n    self._grad_multiplier_fn = (\n        embedding_config_spec.experimental_gradient_multiplier_fn)\n\n  def get_grad_multiplier(self):\n    if self._grad_multiplier_fn:\n      return ops.convert_to_tensor(\n          self._grad_multiplier_fn(tf.compat.v1.train.get_global_step()),\n          dtype=tf.dtypes.float32)\n\n  def has_embedding_tables(self):\n    return bool(self._table_to_config_dict)\n\n  def _create_tpu_embedding(self, mode):\n    \"\"\"Create tpu_embedding.TPUEmbedding based on mode.\"\"\"\n    if mode == model_fn_lib.ModeKeys.TRAIN:\n      batch_size = self._train_batch_size\n    else:\n      batch_size = self._eval_batch_size\n\n    if mode == model_fn_lib.ModeKeys.TRAIN:\n      tpu_embedding_mode = tpu_embedding.TRAINING\n      optimization_parameters = (\n          self._embedding_config_spec.optimization_parameters)\n    elif (mode == model_fn_lib.ModeKeys.EVAL or\n          mode == model_fn_lib.ModeKeys.PREDICT):\n      tpu_embedding_mode = tpu_embedding.INFERENCE\n      optimization_parameters = None\n    else:\n      raise ValueError('Mode {} is not supported.'.format(mode))\n\n    if self._run_config.cluster:\n      master = self._run_config.cluster.master()\n      cluster_spec = self._run_config.cluster.cluster_spec()\n      cluster_def = cluster_spec.as_cluster_def() if cluster_spec else None\n    else:\n      master = (\n          self._run_config.evaluation_master\n          if mode == model_fn_lib.ModeKeys.EVAL else self._run_config.master)\n      cluster_def = None\n    master_job_name = None\n    if self._run_config.tpu_config.tpu_job_name is not None:\n      master_job_name = self._run_config.tpu_config.tpu_job_name\n    tpu_embedding_ = tpu_embedding.TPUEmbedding(\n        self._table_to_config_dict,\n        self._feature_to_config_dict,\n        batch_size,\n        tpu_embedding_mode,\n        master,\n        optimization_parameters,\n        cluster_def,\n        
pipeline_execution_with_tensor_core=self._embedding_config_spec\n        .pipeline_execution_with_tensor_core,\n        partition_strategy=self._partition_strategy,\n        profile_data_directory=self._embedding_config_spec\n        .profile_data_directory,\n        master_job_name=master_job_name)\n    return tpu_embedding_\n\n  def get_tpu_embedding(self, mode):\n    if mode not in self._mode_to_tpu_embedding_dict:\n      self._mode_to_tpu_embedding_dict[mode] = (\n          self._create_tpu_embedding(mode))\n    return self._mode_to_tpu_embedding_dict[mode]\n\n\ndef _maybe_dense_to_sparse(tensor):\n  \"\"\"Possibly convert a dense (rank 1 or 2) tensor to a SparseTensor.\"\"\"\n  # If already sparse, return as is.\n  if isinstance(tensor, tf.sparse.SparseTensor):\n    return tensor\n  indices = tf.compat.v1.where(tensor)\n  values = tf.compat.v1.gather_nd(tensor, indices)\n  shape = tf.compat.v1.shape(tensor, out_type=tf.dtypes.int64)\n  return tf.sparse.SparseTensor(indices, values, shape)\n\n\ndef split_inputs(ctx, features, labels, num_cores_per_batch=1):\n  \"\"\"Splits the dense and sparse tensors inside the features and labels.\"\"\"\n  enqueue_datas = collections.OrderedDict()\n\n  if ctx.embedding_config:\n    tpu_embedding_ = ctx.embedding_config.tpu_embedding\n    for feature_key in tpu_embedding_.feature_to_config_dict:\n      sparse_feature = _get_sparse_feature_from_feature(feature_key, features)\n      max_sequence_length = tpu_embedding_.feature_to_config_dict[\n          feature_key].max_sequence_length\n      combiner = tpu_embedding_._table_to_config_dict[\n          tpu_embedding_._feature_to_config_dict[feature_key].table_id].combiner\n      if max_sequence_length > 0:\n        length_feature_name = (\n            tpu_fc.get_sequence_length_feature_key_name_from_feature_key_name(\n                feature_key))\n        length_feature = tf.math.minimum(\n            fc_utils.sequence_length_from_sparse_tensor(sparse_feature),\n            
max_sequence_length)\n        length_feature.set_shape(ctx.batch_size_for_input_fn)\n        features[length_feature_name] = length_feature\n      weight_key = tpu_embedding_.feature_to_config_dict[feature_key].weight_key\n      sparse_feature_split = _split_tensor(sparse_feature, num_cores_per_batch)\n      if combiner is None and not isinstance(sparse_feature,\n                                             tf.sparse.SparseTensor):\n        # A dense tensor with no combiner was provided so we assume that each\n        # of the embedding_indices belongs to a different sample (setting\n        # sample_indices to None).\n        if weight_key is not None:\n          raise ValueError(\n              'Found weights {} for weighted_categorical_column, which is not'\n              'compatible with sparse feature {} enqueued as dense tensor.'\n              .format(weight_key, feature_key))\n        enqueue_data = []\n        for i in range(num_cores_per_batch):\n          enqueue_data.append(\n              tpu_embedding.EnqueueData(sparse_feature_split[i]))\n      else:\n        weights = None\n        if isinstance(sparse_feature, tf.sparse.SparseTensor):\n          weights = _get_weights_from_features(weight_key, features)\n          weights_split = _split_tensor(weights, num_cores_per_batch)\n        enqueue_data = []\n        for i in range(num_cores_per_batch):\n          split_weights = weights_split[i] if weights else None\n          enqueue_data.append(\n              tpu_embedding.EnqueueData.from_sparse_tensor(\n                  _maybe_dense_to_sparse(sparse_feature_split[i]),\n                  weights=split_weights))\n      enqueue_datas[feature_key] = enqueue_data\n  if ctx.tensor_core_embedding_columns:\n    # pylint: disable=protected-access\n    for column in ctx.tensor_core_embedding_columns:\n      feature_key = column.categorical_column.key\n      sparse_feature = _get_sparse_feature_from_feature(feature_key, features)\n      padded_values, 
padded_mask = (\n          tpu_fc_v2.pad_sparse_embedding_lookup_indices(\n              sparse_feature, column._tensor_core_shape[1]))\n      padded_values.set_shape(\n          [ctx.batch_size_for_input_fn, column._tensor_core_shape[1]])\n      padded_mask.set_shape(\n          [ctx.batch_size_for_input_fn, column._tensor_core_shape[1]])\n      features[feature_key] = padded_values\n      mask_key = feature_key + tpu_fc_v2._TENSOR_CORE_MASK_KEY_SUFFIX\n      if mask_key in features:\n        raise ValueError('Mask key {} for Tensor Core embedding is '\n                         'already in use.'.format(mask_key))\n      features[mask_key] = padded_mask\n    # pylint: enable=protected-access\n\n  # Transpose the enqueue_datas dict into a list of dicts\n  enqueue_datas_list = []\n  for i in range(num_cores_per_batch):\n    enqueue_data = {}\n    for key, value in enqueue_datas.items():\n      enqueue_data[key] = value[i]\n    enqueue_datas_list.append(enqueue_data)\n  return features, labels, enqueue_datas_list\n\n\ndef _split_tensor(tensor, num_splits):\n  \"\"\"Splits tensor into num_splits pieces, returns a list of pieces.\"\"\"\n  if tensor is None:\n    return [None] * num_splits\n  elif num_splits <= 0:\n    return ValueError(\n        'Tensors cannot be split into {} pieces.'.format(num_splits))\n  elif num_splits == 1:\n    return [tensor]\n  elif isinstance(tensor, tf.sparse.SparseTensor):\n    return tf.compat.v2.sparse.split(tensor, num_splits, axis=0)\n  else:\n    return tf.split(tensor, num_splits)\n\n\ndef _get_sparse_feature_from_feature(feature_key, features):\n  \"\"\"Pop and return sparse feature.\"\"\"\n  sparse_feature = features.pop(feature_key)\n  if not sparse_feature.dtype.is_integer:\n    raise ValueError('SparseTensor with string as values are not supported. 
'\n                     'If you are using categorical_column_with_vocabulary_file '\n                     'or categorical_column_with_vocabulary_list, please call '\n                     'your_column.categorical_column._transform_feature({{'\n                     'your_column.key: features[your_column.key]}}) in '\n                     'your input_fn() to convert string to int. '\n                     'feature_key = {}.'.format(feature_key))\n  return sparse_feature\n\n\ndef _get_weights_from_features(weight_key_name, features):\n  \"\"\"Pop and return feature for weights, possibly None.\"\"\"\n  weights = None\n  if weight_key_name is not None:\n    if weight_key_name in features:\n      weights = features.pop(weight_key_name)\n    else:\n      raise ValueError(\n          'Cannot find weights {} for weighted_categorical_column.'\n          ' Please check if the weights are present in feature dict. Also'\n          ' note weight-sharing among weighted_categorical_column is not '\n          'supported on TPU.'.format(weight_key_name))\n    if not isinstance(weights, tf.sparse.SparseTensor):\n      raise ValueError(\n          'weighted_categorical_column with weight key name {} has dense '\n          'weights. Dense weights are not supported on TPU. Please use '\n          'sparse weights instead.'.format(weight_key_name))\n    if weights.dtype is not tf.dtypes.float32:\n      weights = tf.cast(weights, dtype=tf.dtypes.float32)\n  return weights\n\n\ndef get_tpu_embedding_columns(feature_columns):\n  \"\"\"Get feature columns meant to use TPU embedding.\n\n  Args:\n    feature_columns: a list of feature columns.\n\n  Returns:\n    A list of feature columns which can be placed on TPU embedding.\n  \"\"\"\n  tpu_embedding_columns = []\n  for column in feature_columns:\n    if isinstance(column, _TPU_EMBEDDING_COLUMN_CLASSES):\n      tpu_embedding_columns.append(column)\n  return tpu_embedding_columns\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/autotuning_iterations_per_loop_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# =============================================================================\n\"\"\"Tests for auto-tuning iterations_per_loop using TPUStopWithAutoTunedStepHook.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport time\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator.tpu import iteration_count_estimator\nfrom tensorflow_estimator.python.estimator.tpu import tpu_estimator\nfrom tensorflow_estimator.python.estimator.tpu import util as util_lib\n\n\nclass IterationsPerLoopParsingTest(tf.test.TestCase):\n\n  def _parse_and_validate_iterations_per_loop(self, value, expected_value,\n                                              expected_unit):\n    d = util_lib.parse_iterations_per_loop(value)\n    self.assertTrue(d)\n    self.assertEqual(d.value, expected_value)\n    self.assertEqual(d.unit, expected_unit)\n\n  def _parse_and_validate_invalid_iterations_per_loop(self, value):\n    with self.assertRaises(ValueError) as ve:\n      self._parse_and_validate_iterations_per_loop(value, 0, '')\n      self.assertTrue(\n          ve.exception.message.startswith(\n              'Invalid `iterations_per_loop` value.'))\n\n  def test_parsing_iterations_per_loop(self):\n    \"\"\"Tests parsing valid and 
invalid `iterations_per_loop` values.\"\"\"\n\n    self._parse_and_validate_iterations_per_loop(1, 1, 'count')\n    self._parse_and_validate_iterations_per_loop('1', 1, 'count')\n    self._parse_and_validate_iterations_per_loop(2, 2, 'count')\n    self._parse_and_validate_iterations_per_loop(10, 10, 'count')\n    self._parse_and_validate_iterations_per_loop(123, 123, 'count')\n    self._parse_and_validate_iterations_per_loop('123', 123, 'count')\n    self._parse_and_validate_iterations_per_loop('1h', 3600, 'seconds')\n    self._parse_and_validate_iterations_per_loop('1m', 60, 'seconds')\n    self._parse_and_validate_iterations_per_loop('1s', 1, 'seconds')\n    self._parse_and_validate_iterations_per_loop('10h', 10 * 3600, 'seconds')\n    self._parse_and_validate_iterations_per_loop('10m', 10 * 60, 'seconds')\n    self._parse_and_validate_iterations_per_loop('10s', 10, 'seconds')\n    self._parse_and_validate_iterations_per_loop('100h', 100 * 3600, 'seconds')\n    self._parse_and_validate_iterations_per_loop('1000m', 1000 * 60, 'seconds')\n    self._parse_and_validate_iterations_per_loop('10800s', 10800, 'seconds')\n    self._parse_and_validate_invalid_iterations_per_loop(+0)\n    self._parse_and_validate_invalid_iterations_per_loop(0)\n    self._parse_and_validate_invalid_iterations_per_loop(-0)\n    self._parse_and_validate_invalid_iterations_per_loop(-0o12)\n    self._parse_and_validate_invalid_iterations_per_loop('012')\n    self._parse_and_validate_invalid_iterations_per_loop('001')\n    self._parse_and_validate_invalid_iterations_per_loop('0')\n    self._parse_and_validate_invalid_iterations_per_loop('01')\n    self._parse_and_validate_invalid_iterations_per_loop('-1')\n    self._parse_and_validate_invalid_iterations_per_loop('-0h')\n    self._parse_and_validate_invalid_iterations_per_loop('0h')\n    self._parse_and_validate_invalid_iterations_per_loop('0s')\n    self._parse_and_validate_invalid_iterations_per_loop('0m')\n    
self._parse_and_validate_invalid_iterations_per_loop('-1h')\n    self._parse_and_validate_invalid_iterations_per_loop('-1s')\n    self._parse_and_validate_invalid_iterations_per_loop('-1m')\n\n\nclass IterationPredictorTest(tf.test.TestCase):\n\n  def setUp(self):\n    self.estimator = iteration_count_estimator.IterationCountEstimator(\n        capacity=5)\n\n  def test_empty(self):\n    \"\"\"Tests on empty queue.\"\"\"\n    self.assertEqual(self.estimator._min_iterations, self.estimator.get(1))\n    self.assertEqual(self.estimator._min_iterations, self.estimator.get(10))\n\n  def test_reset(self):\n    \"\"\"Tests reset states.\"\"\"\n    self.assertEqual(0, self.estimator._sample_count)\n    self.assertEqual(self.estimator._min_iterations, self.estimator.get(50))\n    self.assertEqual(0, len(self.estimator._buffer_wheel))\n    self.estimator._reset()\n    self.assertEqual(0, self.estimator._sample_count)\n    self.assertEqual(self.estimator._min_iterations, self.estimator.get(100))\n    self.assertEqual(0, len(self.estimator._buffer_wheel))\n    self.estimator.update(9, 1)\n    self.assertEqual(1, self.estimator._sample_count)\n    self.assertEqual(self.estimator._min_iterations, self.estimator.get(8))\n\n  def test_invalid_update(self):\n    \"\"\"Tests reject invalid update.\"\"\"\n    self.estimator._reset()\n    self.estimator.update(0, 0)\n    self.assertEqual(0, len(self.estimator._buffer_wheel))\n    with self.assertRaises(ValueError) as ve:\n      self.assertEqual(self.estimator._min_iterations, self.estimator.get(-1))\n      self.assertIn('Invalid `total_secs`', ve.message)\n    with self.assertRaises(ValueError) as ve:\n      self.assertEqual(self.estimator._min_iterations, self.estimator.get(0))\n      self.assertIn('Invalid `total_secs`', ve.message)\n\n  def test_zero_mean(self):\n    \"\"\"Tests getting estimate when the elapsed time mean value is zero.\"\"\"\n    self.estimator.update(0, 1)\n    self.assertEqual(self.estimator._min_iterations, 
self.estimator.get(10))\n    self.estimator.update(0, 1)\n    self.estimator.update(0, 1)\n    self.assertEqual(self.estimator._min_iterations, self.estimator.get(10))\n\n  def test_diff_less_than_percentage(self):\n    \"\"\"Tests computing diff less than a percentage.\"\"\"\n    self.assertTrue(self.estimator._diff_less_than_percentage(5, 10, 50))\n    self.assertTrue(self.estimator._diff_less_than_percentage(2.5, 10, 75))\n    self.assertTrue(self.estimator._diff_less_than_percentage(10, 10, 5))\n    self.assertTrue(self.estimator._diff_less_than_percentage(9.5, 10, 5))\n    self.assertTrue(self.estimator._diff_less_than_percentage(9.6, 10, 5))\n    self.assertFalse(self.estimator._diff_less_than_percentage(11, 10, 5))\n    self.assertFalse(self.estimator._diff_less_than_percentage(20, 10, 5))\n    self.assertTrue(self.estimator._diff_less_than_percentage(10.3, 10, 5))\n    self.assertTrue(self.estimator._diff_less_than_percentage(10.5, 10, 5))\n    self.assertFalse(self.estimator._diff_less_than_percentage(10.6, 10, 5))\n    self.assertFalse(self.estimator._diff_less_than_percentage(1, 10, 5))\n    self.assertFalse(self.estimator._diff_less_than_percentage(9, 10, 5))\n    with self.assertRaises(ValueError) as ve:\n      self.assertTrue(self.estimator._diff_less_than_percentage(0, 10, 5))\n      self.assertIn('Invalid `actual` value', ve.message)\n    with self.assertRaises(ValueError) as ve:\n      self.assertTrue(self.estimator._diff_less_than_percentage(10, 0, 5))\n      self.assertIn('Invalid `target` value.', ve.message)\n\n  def test_mean_runtime_secs(self):\n    \"\"\"Tests computing mean of step time secs.\"\"\"\n    self.assertEqual(0.0, self.estimator._mean_runtime_secs())\n    self.estimator.update(1, 5)\n    self.assertEqual(1.0, self.estimator._mean_runtime_secs())\n    self.estimator._reset()\n    self.estimator.update(2, 3)\n    self.estimator.update(2, 3)\n    self.estimator.update(2, 3)\n    self.assertEqual(2.0, 
self.estimator._mean_runtime_secs())\n    self.estimator._reset()\n    self.estimator.update(1, 3)\n    self.estimator.update(2, 3)\n    self.assertEqual((1.0 + 2.0) / 2, self.estimator._mean_runtime_secs())\n\n  def test_mean_step_time_secs(self):\n    \"\"\"Tests computing mean of step time secs.\"\"\"\n    self.assertEqual(0.0, self.estimator._mean_step_time_secs())\n    self.estimator.update(1, 5)\n    self.assertEqual(1.0 / 5, self.estimator._mean_step_time_secs())\n    self.estimator._reset()\n    self.estimator.update(2, 3)\n    self.estimator.update(2, 3)\n    self.estimator.update(2, 3)\n    self.assertEqual(2.0 / 3, self.estimator._mean_step_time_secs())\n    self.estimator._reset()\n    self.estimator.update(1, 3)\n    self.estimator.update(2, 3)\n    self.assertEqual((1.0 / 3 + 2.0 / 3) / 2,\n                     self.estimator._mean_step_time_secs())\n\n  def _test_std_step_time_secs(self):\n    \"\"\"Tests computing std deviation of the step time secs.\"\"\"\n    self.assertEqual(0.0, self.estimator._std_step_time_secs())\n    self.estimator.update(1, 5)\n    self.estimator.update(1, 5)\n    self.assertEqual(0.0, self.estimator._std_step_time_secs())\n    self.estimator.update(4, 5)\n    self.assertAlmostEqual(0.283, self.estimator._std_step_time_secs(), 3)\n    self.estimator.update(5, 5)\n    self.assertAlmostEqual(0.357, self.estimator._std_step_time_secs(), 3)\n\n  def test_buffer_capacity(self):\n    \"\"\"Tests to make sure wheel is kept at its capacity.\"\"\"\n    self.estimator._reset(capacity=3)\n    self.assertEqual(0, len(self.estimator._buffer_wheel))\n    self.assertEqual(3, self.estimator._capacity)\n    for _ in range(0, self.estimator._capacity):\n      self.estimator.update(1, 1)\n    self.assertEqual(3, len(self.estimator._buffer_wheel))\n    self.assertEqual(1.0, self.estimator._mean_runtime_secs())\n    self.assertEqual(1.0, self.estimator._mean_step_time_secs())\n    for _ in range(0, self.estimator._capacity):\n      
self.estimator.update(3, 2)\n    self.assertEqual(3, len(self.estimator._buffer_wheel))\n    self.assertEqual(3.0, self.estimator._mean_runtime_secs())\n    self.assertEqual(1.5, self.estimator._mean_step_time_secs())\n\n  def test_partial_wheel(self):\n    \"\"\"Tests getting estimate when the circular buffer is not full.\"\"\"\n    self.assertEqual(0, self.estimator._sample_count)\n    self.estimator.update(5.0, 1)\n    self.assertEqual(1, self.estimator._sample_count)\n    self.assertEqual(5.0, self.estimator._mean_runtime_secs())\n    self.assertEqual(5.0, self.estimator._mean_step_time_secs())\n    self.assertEqual(2, self.estimator.get(10))\n    self.estimator.update(5.0, 1)\n    self.assertEqual(2, self.estimator._sample_count)\n    self.assertEqual(5.0, self.estimator._mean_runtime_secs())\n    self.assertEqual(5.0, self.estimator._mean_step_time_secs())\n    self.assertEqual(3, self.estimator.get(15))\n    self.estimator.update(5.0, 1)\n    self.assertEqual(3, self.estimator._sample_count)\n    self.assertEqual(5.0, self.estimator._mean_runtime_secs())\n    self.assertEqual(5.0, self.estimator._mean_step_time_secs())\n    self.assertEqual(2, self.estimator.get(10))\n\n  def test_update_convergence(self):\n    \"\"\"Tests iterative search convergence.\"\"\"\n    for _ in range(0, self.estimator._capacity):\n      self.estimator.update(2.0, 4)\n    self.assertEqual(2, self.estimator._mean_runtime_secs())\n    self.assertEqual(0.5, self.estimator._mean_step_time_secs())\n\n    iterations = 4\n    target_elapsed_time = 10\n    actual_elapsed_time = 2\n    secs_per_iterations = actual_elapsed_time / iterations\n    for _ in range(0, 5):\n      self.estimator.update(actual_elapsed_time, iterations)\n      iterations = self.estimator.get(target_elapsed_time)\n      actual_elapsed_time = iterations * secs_per_iterations\n    self.assertLessEqual(abs(actual_elapsed_time - target_elapsed_time), 1)\n\n\nclass TPUStopAtStepHookTest(tf.test.TestCase):\n\n  def 
test_invalid_parameters_on_construction(self):\n    \"\"\"Tests invalid parameters on construction.\"\"\"\n    with self.assertRaises(ValueError) as ve:\n      tpu_estimator._TPUStopAtStepHook(\n          util_lib.IterationsPerLoopCounter(value=10, unit='count'),\n          num_steps=None,\n          final_step=None)\n      self.assertEqual(ve.exception.message,\n                       'One of num_steps or final_step must be specified.')\n\n    with self.assertRaises(ValueError) as ve:\n      tpu_estimator._TPUStopAtStepHook(\n          util_lib.IterationsPerLoopCounter(value=10, unit='count'),\n          num_steps=10,\n          final_step=100)\n      self.assertEqual(ve.exception.message,\n                       'Only one of num_steps or final_step can be specified.')\n\n    with self.assertRaises(ValueError) as ve:\n      tpu_estimator._TPUStopAtStepHook(\n          util_lib.IterationsPerLoopCounter(value=10, unit='secs'),\n          num_steps=10,\n          final_step=100)\n      self.assertEqual(\n          ve.exception.message,\n          'Only `count` or `seconds` are accepted as the `iterations_per_loop` '\n          'unit.')\n\n  def _validate_hook_life_cycle(self, iterations_per_loop_counter, num_steps):\n    \"\"\"Test execute hook life-cycle.\n\n    This test validates:\n    - Correctly updating the iterations both for `iterations_per_loop_counter`\n      specified as both `count` and `seconds`\n    - Terminates the session.run() by signaling termination `request_stop()`\n    - The computation of the final iterations count when the remaining step\n      count is smaller than the iterations_per_loop_counter.value.\n\n    Args:\n      iterations_per_loop_counter: This is the number of train steps running in\n        TPU before returning to CPU host for each `Session.run`. 
Can be\n        specified as `count` or `seconds`.\n      num_steps: Number of steps to execute.\n    \"\"\"\n    with self.test_session() as sess:\n      global_step_tensor = tf.compat.v1.train.get_or_create_global_step(\n          sess.graph)\n      global_step_tensor.load(0, session=sess)\n      self.assertEqual(sess.run(global_step_tensor), 0)\n\n      default_iterations = 1\n      hook = tpu_estimator._TPUStopAtStepHook(\n          iterations_per_loop_counter, num_steps=num_steps)\n      self.assertEqual(default_iterations, hook._next_iteration_count)\n      self.assertEqual(num_steps, hook._num_steps)\n      self.assertEqual(None, hook._final_step)\n      self.assertEqual(iterations_per_loop_counter.value,\n                       hook._iterations_per_loop_counter.value)\n      self.assertEqual(iterations_per_loop_counter.unit,\n                       hook._iterations_per_loop_counter.unit)\n\n      def _step(hook, is_final, expected_iterations):\n        hook.begin()\n        hook.after_create_session(sess, None)\n\n        class RunContextMock(object):\n\n          def __init__(self, session):\n            self.session = session\n            self.stop = False\n\n          def request_stop(self):\n            self.stop = True\n\n        class RunValues(object):\n\n          def __init__(self, elapsed_time_secs):\n            self.results = {'elapsed_time': elapsed_time_secs}\n\n        run_context = RunContextMock(sess)\n        run_values = RunValues(1)\n        time.sleep(1.0)\n        hook.after_run(run_context, run_values)\n        if is_final:\n          self.assertEqual(hook._next_iteration_count, expected_iterations)\n          self.assertEqual(run_context.stop, is_final)\n        else:\n          self.assertLessEqual(\n              abs(hook._next_iteration_count - expected_iterations), 1)\n\n      # Estimates iterations when global_step < final_step.\n      global_step = sess.run(tf.compat.v1.train.get_global_step())\n      
self.assertEqual(global_step, 0)\n      _step(hook, is_final=False, expected_iterations=3)\n\n      # Estimates iterations when global_step < final_step.\n      global_step_tensor.load(2, session=sess)\n      _step(hook, is_final=False, expected_iterations=3)\n\n      # Estimates iterations when global_step < final_step, and\n      # (final_step - global_step) < estimated-iterations.\n      global_step_tensor.load(4, session=sess)\n      _step(hook, is_final=False, expected_iterations=1)\n\n      # Estimates iterations when global_step == final_step.\n      global_step_tensor.load(5, session=sess)\n      _step(hook, is_final=True, expected_iterations=0)\n\n  @test_util.deprecated_graph_mode_only\n  def test_hook_life_cycle(self):\n    \"\"\"Tests update iterations.\"\"\"\n    self._validate_hook_life_cycle(\n        util_lib.IterationsPerLoopCounter(value=3, unit='seconds'), 5)\n    self._validate_hook_life_cycle(\n        util_lib.IterationsPerLoopCounter(value=3, unit='count'), 5)\n\n  def _validate_initialization(self, iterations_per_loop_counter, num_steps):\n    with self.test_session() as sess:\n      global_step_tensor = tf.compat.v1.train.get_or_create_global_step(\n          sess.graph)\n      global_step_tensor.load(0, session=sess)\n      self.assertEqual(sess.run(global_step_tensor), 0)\n\n      hook = tpu_estimator._TPUStopAtStepHook(\n          iterations_per_loop_counter, num_steps=num_steps)\n      self.assertEqual(1, hook._next_iteration_count)\n      self.assertEqual(num_steps, hook._num_steps)\n      self.assertEqual(None, hook._final_step)\n      self.assertEqual(iterations_per_loop_counter.value,\n                       hook._iterations_per_loop_counter.value)\n      self.assertEqual(iterations_per_loop_counter.unit,\n                       hook._iterations_per_loop_counter.unit)\n      if iterations_per_loop_counter.unit == 'count':\n        with self.assertRaises(AttributeError) as ve:\n          _ = hook.iteration_count_estimator\n          
self.assertIn('object has no attribute', ve.message)\n      else:\n        self.assertIsInstance(hook._iteration_count_estimator,\n                              iteration_count_estimator.IterationCountEstimator)\n\n  @test_util.deprecated_graph_mode_only\n  def test_initialization(self):\n    \"\"\"Tests initialization.\n\n    This test validates initialization of the Hook using both specifying\n    `iterations_per_loop` as raw `count` and `seconds`.\n    \"\"\"\n    self._validate_initialization(\n        util_lib.IterationsPerLoopCounter(value=3, unit='seconds'), 3)\n    self._validate_initialization(\n        util_lib.IterationsPerLoopCounter(value=600, unit='seconds'), 1)\n    self._validate_initialization(\n        util_lib.IterationsPerLoopCounter(value=3600, unit='seconds'), 5)\n    self._validate_initialization(\n        util_lib.IterationsPerLoopCounter(value=3, unit='count'), 100)\n    self._validate_initialization(\n        util_lib.IterationsPerLoopCounter(value=100, unit='count'), 10)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/error_handling.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ===================================================================\n\"\"\"ErrorRendezvous handler for collecting errors from multiple threads.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport contextlib\nimport sys\nimport threading\nimport time\n\nimport six\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.tools import analytics\n\n_UNINTERESTING_ERRORS = (tf.errors.CancelledError,)\n_IGNORED_ERRORS = (\n    tf.errors.AbortedError,\n    tf.errors.UnavailableError,\n)\n\n_CHECK_NUMERIC_OP_NAME = 'CheckNumerics'\n\n\nclass ErrorRendezvous(object):\n  \"\"\"Resolve errors from multiple threads during TPU execution.\n\n  TPU errors can occur on the infeed or outfeed threads as well as the main\n  training thread.\n\n  Depending on which thread \"wins\" and receives the session error first, we may\n  end up showing users a confusing and non-actionable error message (session\n  cancelled) instead of a root cause (e.g. a bad filename).\n\n  The rendezvous object provides a location to capture these errors until all\n  threads terminate.  
At that point we can choose the most informative error\n  to report.\n  \"\"\"\n\n  def __init__(self, num_sources):\n    # string -> (message, traceback)\n    self._errors = {}\n    self._num_sources = num_sources\n    self._session_cancel_timer = None\n\n  def record_error(self, source, exc_info, session=None):\n    \"\"\"Report an exception from the given source.\n\n    If a session is passed, a timer will be registered to close it after a few\n    seconds.  This is necessary to ensure the main training loop does not hang\n    if an infeed/oufeed error occurs.  We sleep a few seconds to allow a more\n    interesting error from another thread to propagate.\n\n    Args:\n      source: string, source of the error\n      exc_info: Output from `sys.exc_info` (type, value, traceback)\n      session: Session to close after delay.\n    \"\"\"\n    _, value, _ = exc_info\n    # Ignore errors already handled by MonitoredSession\n    if isinstance(value, _IGNORED_ERRORS):\n      return\n\n    self._errors[source] = exc_info\n\n    # If the error is a numeric type, e.g., NaN error, we can assume that the\n    # loop execution completed successfully. 
In this case, we can skip the\n    # `session.close()` logic and wait for the infeed/outfeed threads to\n    # complete as normal.\n    try:\n      if value.op.type == _CHECK_NUMERIC_OP_NAME:\n        analytics.track_numerical_issues(exc_info)\n        return\n    except AttributeError as _:\n      pass\n\n    if session is not None and self._session_cancel_timer is None:\n\n      def _cancel_session():\n        time.sleep(5)\n        tf.compat.v1.logging.error('Closing session due to error %s' % value)\n        try:\n          session.close()\n        except:  # pylint: disable=bare-except\n          tf.compat.v1.logging.error(\n              '\\n\\n\\nFailed to close session after error.'\n              'Other threads may hang.\\n\\n\\n')\n\n      self._session_cancel_timer = threading.Thread(target=_cancel_session,)\n      self._session_cancel_timer.daemon = True\n      self._session_cancel_timer.start()\n\n  def record_done(self, source):\n    \"\"\"Mark execution source `source` as done.\n\n    If an error was originally reported from `source` it is left intact.\n\n    Args:\n      source: `str`, source being recorded\n    \"\"\"\n    tf.compat.v1.logging.info('%s marked as finished', source)\n    if source not in self._errors:\n      self._errors[source] = None\n\n  @contextlib.contextmanager\n  def catch_errors(self, source, session=None):\n    \"\"\"Context manager to report any errors within a block.\"\"\"\n    try:\n      yield\n    except Exception:  # pylint: disable=broad-except\n      self.record_error(source, sys.exc_info(), session)\n\n  def raise_errors(self, timeout_sec=0):\n    \"\"\"Wait for up to `timeout` seconds for all error sources to finish.\n\n    Preferentially raise \"interesting\" errors (errors not in the\n    _UNINTERESTING_ERRORS) set.\n\n    Args:\n      timeout_sec: Seconds to wait for other error sources.\n    \"\"\"\n    for _ in range(timeout_sec):\n      if len(self._errors) == self._num_sources:\n        break\n      
time.sleep(1)\n\n    kept_errors = [(k, v) for (k, v) in self._errors.items() if v is not None]\n\n    # First check for any interesting errors, then fall back on the session\n    # cancelled errors etc.\n    for k, (typ, value, traceback) in kept_errors:\n      if isinstance(value, _UNINTERESTING_ERRORS):\n        continue\n      else:\n        tf.compat.v1.logging.warn('Reraising captured error')\n        six.reraise(typ, value, traceback)\n\n    for k, (typ, value, traceback) in kept_errors:\n      tf.compat.v1.logging.warn('Reraising captured error')\n      six.reraise(typ, value, traceback)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/error_handling_test.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Error Handling tests.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.tpu import error_handling\n\n\nclass ErrorHandlingTest(tf.test.TestCase):\n\n  def catch_and_raise(self, error):\n    er = error_handling.ErrorRendezvous(1)\n    with er.catch_errors(source='infeed'):\n      raise error\n    er.raise_errors()\n\n  def testInterestingError(self):\n    with self.assertRaises(tf.errors.InternalError):\n      self.catch_and_raise(tf.errors.InternalError('message', None, None))\n\n  def testIgnoredError(self):\n    \"\"\"Expect no error to be raised.\"\"\"\n    self.catch_and_raise(tf.errors.AbortedError('message', None, None))\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/iteration_count_estimator.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# =============================================================================\n\"\"\"Estimator that uses past runtime samples to estimate iterations count.\n\nThe estimator helps simplify determining the number of iterations count to spend\non a given alloted time budget. The estimate will get adjusted over time as the\nestimator learns more from collecting per iteration runtime samples.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\n\nimport numpy as np\nimport tensorflow as tf\n\nRuntimeCounter = collections.namedtuple(\n    \"RuntimeCounter\", [\"runtime_secs\", \"steps\", \"step_time_secs\"])\n\n\nclass IterationCountEstimator(object):\n  \"\"\"Estimates iterations count using past iterations runtime.\n\n  The estimator collects iterations elapsed time (in seconds) and store it into\n  a circular buffer. As it learns enough samples, it computes the mean value of\n  the past observed iterations elapsed time to estimate the number of iterations\n  count to run within the alloted time budget in seconds.\n\n  To keep the buffer from growing indefinitely, we limit the size by the virtue\n  of using circular buffer. 
As it uses the mean of iterations runtime to compute\n  the iterations count estimate, setting a larger buffer size will smooth out\n  the estimation. Once the buffer is getting filled up, older values will be\n  dequeued in FIFO order. Setting larger buffer size will make the estimator\n  less sensitive to runtime fluctuations but will result in slower convergence.\n  For faster convergence buffer size can be set smaller but more prone to\n  runtime fluctuations.\n\n  As a safety feature, the estimator will return default iterations value,\n  when:\n  1. The circular buffer is empty (initially).\n  2. The user input is invalid.\n  \"\"\"\n\n  def __init__(self, capacity=20):\n    \"\"\"Constructs a new `IterationsEstimator` instance.\n\n    Args:\n      capacity: Size of circular buffer to hold timer values. Each timer value\n        represents the time spent on the last iterations.\n\n    Raises:\n      ValueError: If one or more parameters specified is invalid.\n    \"\"\"\n    self._reset(capacity=capacity)\n\n  def _reset(self, capacity=20):\n    \"\"\"Resets internal variables.\"\"\"\n    if capacity <= 0:\n      raise ValueError(\"IterationCountEstimator `capacity` must be positive. 
\"\n                       \"Actual:%d.\" % capacity)\n    # A circular buffer with fixed capacity to store the observation time values\n    # and once the buffer is full, the oldest value will be evicted.\n    self._buffer_wheel = collections.deque([])\n    self._capacity = capacity\n    self._min_iterations = 1\n    self._last_iterations = self._min_iterations\n    self._sample_count = 0\n\n  def _mean_runtime_secs(self):\n    return np.mean(self._buffer_wheel, axis=0)[0] if self._buffer_wheel else 0\n\n  def _mean_step_time_secs(self):\n    return np.mean(self._buffer_wheel, axis=0)[2] if self._buffer_wheel else 0\n\n  def _std_step_time_secs(self):\n    return np.std(self._buffer_wheel, axis=0)[2] if self._buffer_wheel else 0\n\n  def _diff_less_than_percentage(self, actual, target, percentage):\n    \"\"\"Checks if `actual` value is within a `percentage` to `target` value.\n\n    Args:\n      actual: Actual value.\n      target: Target value.\n      percentage: Max percentage threshold.\n\n    Returns:\n      True if the ABS(`actual` - `target`) is less than or equal to `percentage`\n        , otherwise False.\n\n    Raise:\n      ValueError: If `total_secs` value is not positive.\n    \"\"\"\n    if actual == 0:\n      raise ValueError(\"Invalid `actual` value. Value must not be zero.\")\n    if target == 0:\n      raise ValueError(\"Invalid `target` value. 
Value must not be zero.\")\n    return (float(abs(target - actual)) / target) <= percentage * 0.01\n\n  def _is_step_time_stable(self):\n    \"\"\"Checks if the step time has stabilized.\n\n    We define stability a function of small stdev and after running for some\n    time.\n\n    Returns:\n      True if stability is reached, False otherwise.\n    \"\"\"\n    std = self._std_step_time_secs()\n    return std < 0.03 and self._sample_count > self._capacity\n\n  def update(self, runtime_secs, count):\n    \"\"\"Updates the unit time spent per iteration.\n\n    Args:\n      runtime_secs: The total elapsed time in seconds.\n      count: The number of iterations.\n    \"\"\"\n    if runtime_secs <= 0.0:\n      tf.compat.v1.logging.debug(\n          \"Invalid `runtime_secs`. Value must be positive. Actual:%.3f.\",\n          runtime_secs)\n      return\n    if count <= 0.0:\n      tf.compat.v1.logging.debug(\n          \"Invalid samples `count`. Value must be positive. Actual:%d.\", count)\n      return\n\n    if len(self._buffer_wheel) >= self._capacity:\n      self._buffer_wheel.popleft()\n    step_time_secs = float(runtime_secs) / count\n    self._buffer_wheel.append(\n        RuntimeCounter(\n            runtime_secs=runtime_secs,\n            steps=count,\n            step_time_secs=step_time_secs))\n    self._sample_count += 1\n\n  def get(self, total_secs):\n    \"\"\"Gets the iterations count estimate.\n\n    If recent predicted iterations are stable, re-use the previous value.\n    Otherwise, update the prediction value based on the delta between the\n    current prediction and the expected number of iterations as determined by\n    the per-step runtime.\n\n    Args:\n      total_secs: The target runtime in seconds.\n\n    Returns:\n      The number of iterations as estimate.\n\n    Raise:\n      ValueError: If `total_secs` value is not positive.\n    \"\"\"\n    if total_secs <= 0:\n      raise ValueError(\n          \"Invalid `total_secs`. 
It must be positive number. Actual:%d\" %\n          total_secs)\n    if not self._buffer_wheel:\n      tf.compat.v1.logging.debug(\n          \"IterationCountEstimator has no sample(s). Returns min iterations:%d.\",\n          self._min_iterations)\n      return self._min_iterations\n\n    mean_runtime_secs = self._mean_runtime_secs()\n    mean_step_time_secs = self._mean_step_time_secs()\n    std_step_time_secs = self._std_step_time_secs()\n    projected_iterations = total_secs / mean_step_time_secs\n    last_runtime_secs = self._buffer_wheel[-1].runtime_secs\n    delta_iterations = projected_iterations - self._last_iterations\n    # Stabilizes the search once it is close enough to the target runtime and\n    # the step time is stable within range bound.\n    if ((self._diff_less_than_percentage(last_runtime_secs, total_secs, 10) or\n         self._diff_less_than_percentage(mean_runtime_secs, total_secs, 5)) and\n        self._is_step_time_stable()):\n      delta_iterations = 0\n    self._last_iterations += delta_iterations\n    self._last_iterations = max(self._last_iterations, self._min_iterations)\n    tf.compat.v1.logging.info(\n        \"IterationCountEstimator -- target_runtime:%.3fs. last_runtime:%.3fs. \"\n        \"mean_runtime:%.3fs. last_step_time:%.3f. std_step_time:%.3f. \"\n        \"mean_step_time:%.3fs. delta_steps:%.2f. prev_steps:%.2f. \"\n        \"next_steps:%.2f.\", total_secs, last_runtime_secs, mean_runtime_secs,\n        self._buffer_wheel[-1].step_time_secs, std_step_time_secs,\n        mean_step_time_secs, delta_iterations, self._buffer_wheel[-1].steps,\n        self._last_iterations)\n    return int(self._last_iterations + 0.5)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/spatial_partitioning_api.md",
    "content": "# Spatial partitioning\n\nSpatial partitioning allows us to run models with larger input images. Typically\nthese models will be too large to fit on a single TPU core.\n\nSpatial partitioning uses multiple cores to process different parts of the input\ntensor. Each core communicates with the other cores when necessary to merge\noverlapping parts of the computation. All the complicated merging logic is\nimplemented in the XLA compiler, therefore you only need to configure how the\ninputs to your model are partitioned.\n\nNote: Spatial partitioning only distributes activations across multiple cores.\nEach core still maintains its own copy of the model weights. For most image\nmodel, activations use more memory than the model weights.\n\n## Enabling Spatial Partitioning with TPUEstimator\n\nSpatial partitioning doesn't require any code change in your model. You only \nneed to specify the spatial partition parameters in your TPUConfig.\n\n```\ntpu_config=tpu_config.TPUConfig(\n    iterations_per_loop=100,\n    num_cores_per_replica=4,\n    per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2,\n    input_partition_dims=[[1, 4, 1, 1], None]]\n\n```\n\n`per_host_input_for_training` must be set to PER_HOST_V2 for spatial\npartitioning: this means you must have a tf.data based input pipeline.\n`num_cores_per_replica` determines the maximum number partitions we can split.\n`input_partition_dims` is a list with two elements: `feature_partition_dims` and\n`label_partition_dims` describes how to partition the input tensors. The\nstructure of `feature_partition_dims` and `label_partition_dims` must match the\nstructure of features and labels from input_fn.\n\n### Partitioning when features and labels are single tensors\n\n`features` or `labels` can be a single tensor. In this case,\n`feature_partition_dims` or `label_partition_dims` must be a list/tuple of\nintegers or None. 
The length of the list/tuple must equal the number of\ndimensions of the tensor. For example, if `features` is an image tensor with\nshape [N, H, W, C], the `feature_partition_dims` must be a list/tuple with 4\nintegers.\n\n```\nfeatures = image_tensor # [N, H, W, C]\nlabels = class_label # [N]\n\ninput_partition_dims = [[1,4,1,1], None]\n\n```\n\n### Partitioning when features or labels are a dictionary\n\n`features` or `labels` can alternatively be a dictionary from `feature_name` to\na `Tensor`. In this case `feature_partition_dims` or `label_partition_dims` must\nbe a dict with exactly the same keys, and the value is a list/tuple of integers\nor None.\n\n```\nfeatures = {'image': image_tensor, 'image_mask': mask_tensor}\nlabels =  {'class_label': class_id, 'mask': mask_id}\n\ninput_partition_dims = [\n   {'image': [1,4,1,1], 'image_mask': [1, 2, 2,1]},\n   {'class_label': [1], 'mask': None}]\n\n```\n\nIn this example, both `features` and `labels` are dictionaries. Therefore the\n`input_partition_dims` contains two dicts with the same structure: the first\ndict in `input_partition_dims` has two keys ‘image’ and ‘image_mask’ to match\nthe tensors in features. The value is a list of integers that describes how to\npartition the tensor. 'class_label': [1] means we send the class_label tensor to\ncore 0 only.\n\n### Partitioning when features are a dict, labels are a single tensor\n\n`features` and `labels` could be any of the aforementioned formats. The rules\nfor `feature_partition_dims` and `label_partition_dims` are applied separately.\n\n```\nfeatures = {'image': image_tensor, 'image_mask': mask_tensor}\nlabels =  class_label # [N]\n\ninput_partition_dims = [\n   {'image': [1,4,1,1], 'image_mask': [1, 2, 2,1]},\n   [1]]\n\n```\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/tpu_config.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ===================================================================\n\"\"\"A RunConfig subclass with TPU support.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport json\nimport os\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator import run_config as run_config_lib\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.tpu import util as util_lib\n\n# pylint: disable=protected-access\n_TF_CONFIG_ENV = run_config_lib._TF_CONFIG_ENV\n_SERVICE_KEY = run_config_lib._SERVICE_KEY\n_TPU_WORKER_JOB_NAME = 'tpu_worker_job_name'\n# pylint: enable=protected-access\n\n\n@estimator_export(v1=['estimator.tpu.InputPipelineConfig'])\nclass InputPipelineConfig(object):\n  r\"\"\"Please see the definition of these values in TPUConfig.\n\n  @compatibility(TF2)\n  TPU Estimator manages its own TensorFlow graph and session, so it is not\n  compatible with TF2 behaviors. We recommend that you migrate to the newer\n  `tf.distribute.TPUStrategy`. 
See the\n  [TPU guide](https://www.tensorflow.org/guide/tpu) for details.\n  @end_compatibility\n  \"\"\"\n  PER_SHARD_V1 = 1\n  PER_HOST_V1 = 2\n  PER_HOST_V2 = 3\n  BROADCAST = 4\n  SLICED = 5\n\n\n@estimator_export(v1=['estimator.tpu.TPUConfig'])\nclass TPUConfig(\n    collections.namedtuple('TPUConfig', [\n        'iterations_per_loop',\n        'num_shards',\n        'num_cores_per_replica',\n        'per_host_input_for_training',\n        'tpu_job_name',\n        'initial_infeed_sleep_secs',\n        'input_partition_dims',\n        'eval_training_input_configuration',\n        'experimental_host_call_every_n_steps',\n        'experimental_allow_per_host_v2_parallel_get_next',\n        'experimental_feed_hook',\n    ])):\n  r\"\"\"TPU related configuration required by `TPUEstimator`.\n\n  Args:\n    iterations_per_loop: This is the number of train steps running in TPU system\n      before returning to CPU host for each `Session.run`. This means global\n      step is increased `iterations_per_loop` times in one `Session.run`. It is\n      recommended to be set as number of global steps for next checkpoint. Note\n      that in evaluation don't use this value, instead we run total eval `steps`\n      on TPU for a single `Session.run`.\n      [Experimental]: `iterations_per_loop` can be specified as a time interval.\n        To specify N seconds in one `Session.run`, one can specify it as `Ns`\n        and substitute the N with the number of desired seconds.\n        Alternatively, the unit of time can also be specified in minutes or\n        hours, e.g. `3600s` or `60m` or `1h`.\n    num_shards: (Deprecated, ignored by TPUEstimator). The number of model\n      replicas in the system. For non-model-parallelism case, this number equals\n      the total number of TPU cores. 
For model-parallelism, the total number of\n      TPU cores equals num_cores_per_replica * num_shards.\n    num_cores_per_replica: Defaults to `None`, which disables model parallelism.\n      An integer which describes the number of TPU cores per model replica. This\n      is required by model-parallelism which enables partitioning the model to\n      multiple cores. Currently num_cores_per_replica must be 1, 2, 4, 8, 16,\n      32, 64, or 128.\n    per_host_input_for_training: If `True`, for `PER_HOST_V1`, the `input_fn` is\n      invoked once on each host, and the number of hosts must be smaller or\n      equal to the number of replicas. For PER_HOST_V2, the `input_fn` is\n      invoked once for each host (if the number of hosts is less than the number\n      of replicas) or replica (if the number of replicas is less than the number\n      of hosts). With the per-core input pipeline configuration, it is invoked\n      once for each core. With a global batch size `train_batch_size` in\n      `TPUEstimator` constructor, the batch size for each shard is\n      `train_batch_size` // #hosts in the `True` or `PER_HOST_V1` mode. In\n      `PER_HOST_V2` mode, it is `train_batch_size` // #cores. In `BROADCAST`\n      mode, `input_fn` is only invoked once on host 0 and the tensors are\n      broadcasted to all other replicas. The batch size equals\n      `train_batch_size`. With the per-core input pipeline configuration, the\n      shard batch size is also `train_batch_size` // #cores.\n      Note: per_host_input_for_training==PER_SHARD_V1 only supports mode.TRAIN.\n    tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred\n      within TPUEstimator, however when using ClusterSpec propagation in more\n      esoteric cluster configurations, you may need to specify the job name as a\n      string.\n    initial_infeed_sleep_secs: The number of seconds the infeed thread should\n      wait before enqueueing the first batch. 
This helps avoid timeouts for\n      models that require a long compilation time.\n    input_partition_dims: A nested list to describe the partition dims for all\n      the tensors from input_fn(). The structure of input_partition_dims must\n      match the structure of `features` and `labels` from input_fn(). The total\n      number of partitions must match\n      `num_cores_per_replica`. For example, if input_fn() returns two tensors:\n        images with shape [N, H, W, C] and labels [N]. input_partition_dims =\n        [[1, 2, 2, 1], None] will split the images to 4 pieces and feed into 4\n        TPU cores. labels tensor are directly broadcasted to all the TPU cores\n        since the partition dims is `None`.\n      Current limitations: This feature is only supported with the PER_HOST_V2\n        input mode.\n    eval_training_input_configuration: If `SLICED`, `input_fn` is only invoked\n      once on host 0 and the tensors are broadcasted to all other replicas.\n      Unlike per_host_input_for_training=BROADCAST, each replica will only get a\n      slice of the data instead of a whole copy. If `PER_HOST_V1`, the behaviour\n      is determined by per_host_input_for_training.\n    experimental_host_call_every_n_steps: Within a training loop, this argument\n      sets how often host calls are performed during training. Host calls will\n      be evaluated every n steps within a training loop where n is the value of\n      this argument.\n    experimental_allow_per_host_v2_parallel_get_next: When enabled, allows\n      concurrent execution of dataset get next calls when using PER_HOST_V2\n      input. 
May result in a performance increase for models with a small step\n      time, but as a consequence TPUEstimator may non-deterministically\n      distribute batches to different cores, rather than guaranteeing round\n      robin behavior.\n    experimental_feed_hook: This is a class which user can provide to the TPU\n      estimator to override the default TPUInfeedOutfeedSessionHook implementation\n      and add customized implementation to handle infeed outfeed logic. If\n      given class is None, TPU estimator uses default TPUInfeedOutfeedSessionHook\n      implementation in tpu_estimator.py. If not None, TPU estimator uses this\n      customized tpu infeed outfeed session hook class to override the\n      default one.\n\n  Raises:\n      ValueError: If `num_cores_per_replica` is not 1, 2, 4, 8, ..., 128.\n\n  @compatibility(TF2)\n  TPU Estimator manages its own TensorFlow graph and session, so it is not\n  compatible with TF2 behaviors. We recommend that you migrate to the newer\n  `tf.distribute.TPUStrategy`. 
See the\n  [TPU guide](https://www.tensorflow.org/guide/tpu) for details.\n  @end_compatibility\n  \"\"\"\n\n  def __new__(cls,\n              iterations_per_loop=2,\n              num_shards=None,\n              num_cores_per_replica=None,\n              per_host_input_for_training=True,\n              tpu_job_name=None,\n              initial_infeed_sleep_secs=None,\n              input_partition_dims=None,\n              eval_training_input_configuration=InputPipelineConfig.PER_HOST_V1,\n              experimental_host_call_every_n_steps=1,\n              experimental_allow_per_host_v2_parallel_get_next=False,\n              experimental_feed_hook=None):\n\n    # Check iterations_per_loop.\n    util_lib.parse_iterations_per_loop(iterations_per_loop)\n\n    # Check num_shards.\n    if num_shards is not None:\n      util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards')\n\n    if input_partition_dims is not None:\n      if len(input_partition_dims) != 1 and len(input_partition_dims) != 2:\n        raise ValueError(\n            'input_partition_dims must be a list/tuple with one or two'\n            ' elements.')\n\n      if per_host_input_for_training is not InputPipelineConfig.PER_HOST_V2:\n        raise ValueError(\n            'input_partition_dims is only supported in PER_HOST_V2 mode.')\n\n      if num_cores_per_replica is None:\n        raise ValueError(\n            'input_partition_dims requires setting num_cores_per_replica.')\n\n    # Check num_cores_per_replica\n    if num_cores_per_replica is not None:\n      if num_cores_per_replica not in ([1, 2, 4, 8, 16, 32, 64, 128]):\n        raise ValueError(\n            'num_cores_per_replica must be 1, 2, 4, 8, 16, 32, 64, 128; '\n            'got {}'.format(str(num_cores_per_replica)))\n\n    if eval_training_input_configuration not in [\n        InputPipelineConfig.PER_HOST_V1, InputPipelineConfig.SLICED\n    ]:\n      raise ValueError(\n          'eval_training_input_configuration must be 
PER_HOST_V1 or SLICED;'\n          ' got {}'.format(str(eval_training_input_configuration)))\n\n    # per_host_input_for_training may be True, False, or integer in [1..3].\n    # Map legacy values (True, False) to numeric values.\n    if per_host_input_for_training is False:\n      per_host_input_for_training = InputPipelineConfig.PER_SHARD_V1\n    elif per_host_input_for_training is True:\n      per_host_input_for_training = InputPipelineConfig.PER_HOST_V1\n\n    # Check initial_infeed_sleep_secs.\n    if initial_infeed_sleep_secs:\n      util_lib.check_positive_integer(initial_infeed_sleep_secs,\n                                      'TPUConfig initial_infeed_sleep_secs')\n\n    tpu_job_name = tpu_job_name or _get_tpu_job_name_from_tf_config()\n\n    return super(TPUConfig, cls).__new__(\n        cls,\n        iterations_per_loop=iterations_per_loop,\n        num_shards=num_shards,\n        num_cores_per_replica=num_cores_per_replica,\n        per_host_input_for_training=per_host_input_for_training,\n        tpu_job_name=tpu_job_name,\n        initial_infeed_sleep_secs=initial_infeed_sleep_secs,\n        input_partition_dims=input_partition_dims,\n        eval_training_input_configuration=eval_training_input_configuration,\n        experimental_host_call_every_n_steps=(\n            experimental_host_call_every_n_steps),\n        experimental_allow_per_host_v2_parallel_get_next=(\n            experimental_allow_per_host_v2_parallel_get_next),\n        experimental_feed_hook=(experimental_feed_hook))\n\n\n@estimator_export(v1=['estimator.tpu.RunConfig'])\nclass RunConfig(run_config_lib.RunConfig):\n  \"\"\"RunConfig with TPU support.\"\"\"\n\n  def __init__(self,\n               tpu_config=None,\n               evaluation_master=None,\n               master=None,\n               cluster=None,\n               **kwargs):\n    \"\"\"Constructs a RunConfig.\n\n    Args:\n      tpu_config: the TPUConfig that specifies TPU-specific configuration.\n      
evaluation_master: a string. The address of the master to use for eval.\n        Defaults to master if not set.\n      master: a string. The address of the master to use for training.\n      cluster: a ClusterResolver\n      **kwargs: keyword config parameters.\n\n    Raises:\n      ValueError: if cluster is not None and the provided session_config has a\n        cluster_def already.\n\n    @compatibility(TF2)\n    TPU Estimator manages its own TensorFlow graph and session, so it is not\n    compatible with TF2 behaviors. We recommend that you migrate to the newer\n    `tf.distribute.TPUStrategy`. See the\n    [TPU guide](https://www.tensorflow.org/guide/tpu) for details.\n    @end_compatibility\n    \"\"\"\n    super(RunConfig, self).__init__(**kwargs)\n    self._tpu_config = tpu_config or TPUConfig()\n    self._cluster = cluster\n\n    # If user sets master and/or evaluation_master explicitly, including empty\n    # string '', take it. Otherwise, take the values set by parent class.\n    if master is not None:\n      if cluster is not None:\n        raise ValueError('Both master and cluster are set.')\n      self._master = master\n    else:\n      if cluster:\n        self._master = cluster.master()\n\n    if evaluation_master is not None:\n      self._evaluation_master = evaluation_master\n    elif (not self._evaluation_master and\n          self.task_type != run_config_lib.TaskType.EVALUATOR):\n      # If the task type is EVALUATOR, it means some cluster manager sets the\n      # TF_CONFIG. In that case, we respect the configuration in TF_CONFIG.\n      #\n      # Otherwise, it means user executes the code without external cluster\n      # manager. 
For that, we optimize the user experience by setting\n      # evaluation_master to master, unless user overwrites it.\n      self._evaluation_master = self._master\n\n    # Set the ClusterSpec to use\n    if cluster:\n      self._cluster_spec = cluster.cluster_spec()\n\n      # Merge the cluster_def into the ConfigProto.\n      if self._session_config is None:  # pylint: disable=access-member-before-definition\n        self._session_config = tf.compat.v1.ConfigProto(\n            allow_soft_placement=True, isolate_session_state=True)\n      if self._session_config.HasField('cluster_def'):\n        raise ValueError('You cannot provide a ClusterResolver and '\n                         'session_config.cluster_def.')\n      if self._cluster_spec:\n        self._session_config.cluster_def.CopyFrom(\n            self._cluster_spec.as_cluster_def())\n\n  def _maybe_overwrite_session_config_for_distributed_training(self):\n    # Overrides the parent class session_config overwrite for between-graph. TPU\n    # runs with in-graph, which should not have device filter. 
Doing nothing\n    # (\"pass\") basically disables it.\n    pass\n\n  @property\n  def evaluation_master(self):\n    return self._evaluation_master\n\n  @property\n  def master(self):\n    return self._master\n\n  @property\n  def tpu_config(self):\n    return self._tpu_config\n\n  @property\n  def cluster(self):\n    return self._cluster\n\n  def replace(self, **kwargs):\n    if 'tpu_config' not in kwargs:\n      return super(RunConfig, self).replace(**kwargs)\n\n    tpu_config = kwargs.pop('tpu_config')\n    new_instance = super(RunConfig, self).replace(**kwargs)\n    new_instance._tpu_config = tpu_config  # pylint: disable=protected-access\n    return new_instance\n\n\ndef _get_tpu_job_name_from_tf_config():\n  \"\"\"Extracts the TPU job name from TF_CONFIG env variable.\"\"\"\n  # TODO(xiejw): Extends this to support both TF_CONFIG env variable and cluster\n  # spec propagation.\n  tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV, '{}'))\n  tpu_job_name = tf_config.get(_SERVICE_KEY, {}).get(_TPU_WORKER_JOB_NAME)\n  if tpu_job_name:\n    tf.compat.v1.logging.info('Load TPU job name from TF_CONFIG: %s',\n                              tpu_job_name)\n  return tpu_job_name\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/tpu_config_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"TPU RunConfig tests.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport json\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator import run_config as run_config_lib\nfrom tensorflow_estimator.python.estimator.tpu import tpu_config as tpu_config_lib\nfrom tensorflow_estimator.python.estimator.tpu import util as util_lib\n\n\ndef _set_tf_config_env_variable(tf_config):\n  return tf.compat.v1.test.mock.patch.dict('os.environ',\n                                           {'TF_CONFIG': json.dumps(tf_config)})\n\n\nclass TPURunConfigTest(tf.test.TestCase):\n\n  def test_no_session_config_set_in_local_case(self):\n    run_config = tpu_config_lib.RunConfig()\n    self.assertIsNone(run_config.session_config)\n\n  def test_no_session_config_overwrite_in_local_case(self):\n    session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)\n    run_config = tpu_config_lib.RunConfig(session_config=session_config)\n    self.assertEqual(session_config, run_config.session_config)\n\n  def test_no_session_config_set_with_cluster_spec(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host3:3'],\n            
run_config_lib.TaskType.WORKER: ['host3:4']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.CHIEF,\n            'index': 0\n        }\n    }\n    with _set_tf_config_env_variable(tf_config):\n      run_config = tpu_config_lib.RunConfig()\n      self.assertIsNone(run_config.session_config)\n\n  def test_no_session_config_overwrite_with_cluster_spec(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host3:3'],\n            run_config_lib.TaskType.WORKER: ['host3:4']\n        },\n        'task': {\n            'type': run_config_lib.TaskType.CHIEF,\n            'index': 0\n        }\n    }\n    with _set_tf_config_env_variable(tf_config):\n      session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)\n      run_config = tpu_config_lib.RunConfig(session_config=session_config)\n      self.assertEqual(session_config, run_config.session_config)\n\n  def test_fail_with_invalid_num_shards(self):\n    with self.assertRaisesRegexp(ValueError, 'must be positive'):\n      tpu_config_lib.RunConfig(\n          tpu_config=tpu_config_lib.TPUConfig(num_shards=0))\n\n  def _validate_invalid_iterations_per_loop(self, iterations_per_loop):\n    with self.assertRaisesRegexp(ValueError, 'must be positive'):\n      tpu_config_lib.RunConfig(\n          tpu_config=tpu_config_lib.TPUConfig(\n              iterations_per_loop=iterations_per_loop))\n\n  def test_fail_with_iterations_per_loop(self):\n    self._validate_invalid_iterations_per_loop(0)\n    self._validate_invalid_iterations_per_loop(-1)\n    self._validate_invalid_iterations_per_loop('-1h')\n    self._validate_invalid_iterations_per_loop('-1m')\n    self._validate_invalid_iterations_per_loop('-1s')\n\n  def test_fail_with_invalid_num_cores_per_replica(self):\n    with self.assertRaisesRegexp(\n        ValueError, 'num_cores_per_replica must be 1, 2, 4, 8, 16, 32, 64, 128;'\n        ' got 7'):\n      
tpu_config_lib.TPUConfig(num_cores_per_replica=7)\n\n  def _evaluate_iterations_per_loop_in_seconds(self, value, expected_value,\n                                               expected_unit):\n    config = tpu_config_lib.RunConfig(\n        tpu_config=tpu_config_lib.TPUConfig(iterations_per_loop=value))\n    self.assertEqual(config.tpu_config.iterations_per_loop, value)\n    d = util_lib.parse_iterations_per_loop(\n        config.tpu_config.iterations_per_loop)\n    self.assertEqual(expected_value, d.value)\n    self.assertEqual(expected_unit, d.unit)\n\n  def test_valid_iterations_per_loop(self):\n    self._evaluate_iterations_per_loop_in_seconds(1, 1, 'count')\n    self._evaluate_iterations_per_loop_in_seconds(100, 100, 'count')\n    self._evaluate_iterations_per_loop_in_seconds('300s', 300, 'seconds')\n    self._evaluate_iterations_per_loop_in_seconds('1m', 60, 'seconds')\n    self._evaluate_iterations_per_loop_in_seconds('1h', 3600, 'seconds')\n\n\nclass TPURunConfigMasterTest(tf.test.TestCase):\n\n  def test_default_values(self):\n    run_config = tpu_config_lib.RunConfig()\n    self.assertEqual('', run_config.master)\n    self.assertEqual('', run_config.evaluation_master)\n\n  def test_user_provided_master_and_evaluation_master(self):\n    run_config = tpu_config_lib.RunConfig(\n        master='_master_123', evaluation_master='_eval_master_123')\n    self.assertEqual('_master_123', run_config.master)\n    self.assertEqual('_eval_master_123', run_config.evaluation_master)\n\n  def test_evaluation_master_defaults_to_master(self):\n    run_config = tpu_config_lib.RunConfig(master='_master_123')\n    self.assertEqual('_master_123', run_config.master)\n    self.assertEqual('_master_123', run_config.evaluation_master)\n\n  def test_tf_config(self):\n    tf_config = {\n        'session_master': '_master_123',\n        'eval_session_master': '_eval_master_123'\n    }\n    with _set_tf_config_env_variable(tf_config):\n      run_config = tpu_config_lib.RunConfig()\n   
   self.assertEqual('_master_123', run_config.master)\n      self.assertEqual('_eval_master_123', run_config.evaluation_master)\n\n  def test_evaluation_master_defaults_to_master_in_tf_config(self):\n    tf_config = {\n        'session_master': '_master_123',\n    }\n    with _set_tf_config_env_variable(tf_config):\n      run_config = tpu_config_lib.RunConfig()\n      self.assertEqual('_master_123', run_config.master)\n      self.assertEqual('_master_123', run_config.evaluation_master)\n\n  def test_respect_evaluation_master_in_tf_config(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n        },\n        'task': {\n            'type': run_config_lib.TaskType.EVALUATOR,\n            'index': 0\n        },\n    }\n    with _set_tf_config_env_variable(tf_config):\n      run_config = tpu_config_lib.RunConfig(master='_something')\n      self.assertEqual('', run_config.evaluation_master)\n\n  def test_user_overwrites_tf_config(self):\n    tf_config = {\n        'session_master': '_master_123',\n        'eval_session_master': '_eval_master_123'\n    }\n    with _set_tf_config_env_variable(tf_config):\n      run_config = tpu_config_lib.RunConfig(\n          master='_new_master_123', evaluation_master='_new_eval_master_123')\n      self.assertEqual('_new_master_123', run_config.master)\n      self.assertEqual('_new_eval_master_123', run_config.evaluation_master)\n\n  def test_user_overwrites_master_in_tf_config(self):\n    tf_config = {\n        'session_master': '_master_123',\n        'eval_session_master': '_eval_master_123'\n    }\n    with _set_tf_config_env_variable(tf_config):\n      run_config = tpu_config_lib.RunConfig(master='_new_master_123')\n      self.assertEqual('_new_master_123', run_config.master)\n      self.assertEqual('_eval_master_123', run_config.evaluation_master)\n\n\nclass TPUJobNameTest(tf.test.TestCase):\n\n  def test_default_name(self):\n    config = tpu_config_lib.RunConfig()\n    
self.assertIsNone(config.tpu_config.tpu_job_name)\n\n  def test_with_tf_config(self):\n    tf_config = {'service': {'tpu_worker_job_name': '_my_new_name',}}\n    with _set_tf_config_env_variable(tf_config):\n      config = tpu_config_lib.RunConfig()\n      self.assertEqual('_my_new_name', config.tpu_config.tpu_job_name)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/tpu_context.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ===================================================================\n\"\"\"TPU system metadata and associated tooling.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom contextlib import contextmanager\nimport copy\nimport tensorflow as tf\nfrom tensorflow.python.distribute import distribute_lib\nfrom tensorflow.python.ops import summary_ops_v2\nfrom tensorflow.python.tpu import device_assignment as tpu_device_assignment\nfrom tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.tpu import _tpu_estimator_embedding\nfrom tensorflow_estimator.python.estimator.tpu import tpu_config\n\n_DEFAULT_JOB_NAME = 'tpu_worker'\n_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'\n_LOCAL_MASTERS = ('', 'local')\n# TODO(pgavin): support PF 3D mesh\n_NUM_CORES_TO_COMPUTATION_SHAPE = {\n    1: [1, 1, 1, 1],\n    2: [1, 1, 1, 2],\n    4: [1, 2, 1, 2],\n    8: [2, 2, 1, 2],\n    16: [4, 2, 1, 2],\n    32: [4, 4, 1, 2],\n    64: [8, 4, 1, 2],\n    128: [8, 8, 1, 2],\n}\n\n\nclass TPUContext(object):\n  \"\"\"A context that holds the current configuration of the TPU computation.\n\n  TPUContext was designed for getting TPU context information 
when calling\n  input_fn. It can be called in model_fn as well.\n\n  User is not expected to construct the instance from constructor. The only\n  legitimate way to get the instance is either in `input_fn`:\n\n  ```\n  def input_fn(params):\n    batch_size = params['batch_size']\n    context = params['context']\n    # ...\n  ```\n\n  or in `model_fn`\n\n  ```\n  def model_fn(params):\n    batch_size = params['batch_size']\n    context = params['context']\n    # ...\n  ```\n\n  Most of the fields of TPUContext are useful for both `input_fn` and\n  `model_fn`. Exceptions are:\n\n  1. `input_fn` only:\n\n    current_input_fn_deployment\n    current_host\n\n  2. `model_fn` only:\n\n    device_assignment\n\n  \"\"\"\n\n  def __init__(self,\n               internal_ctx,\n               input_device=None,\n               invocation_index=None,\n               call_from_input_fn=True,\n               host_id=None):\n    self._internal_ctx = internal_ctx\n    self._input_device = input_device\n    self._invocation_index = invocation_index\n    self._call_from_input_fn = call_from_input_fn\n    self._host_id = host_id\n\n  def current_input_fn_deployment(self):\n    \"\"\"The configuration of the current input_fn invocation.\n\n    The configuration depends on `TPUConfig.per_host_input_for_training`. See\n    `TPUConfig` for details.\n\n    Only set in params dict of input_fn\n\n    Returns:\n      A tuple of\n        1. Device spec string: String, is the current CPU host where the\n           input_fn is invoked.\n        2. Current invocation index: Int, 0-based index of the input_fn\n           invocation. See next item for details.\n        3. Total invocation count: Int, the total number of times to invoke the\n           input_fn on all CPU hosts. Each invocation will be passed with a new\n           `TPUContext` instance with current invocation index set properly.\n        4. 
Total number of replicas consumed by current_invocation: Int, the\n           number of replicas fed by the data returned by current input_fn. For\n           example, for per_core input pipeline deployment\n           and non-model-parallelism, total invocation count is equal to\n           the number of cores in the system and num replicas consumed by\n           current invocation is 1. For per-host v2 input pipeline deployment,\n           total invocation count is equal to the number of hosts in the system\n           and num replicas consumed by current invocation is equal to number of\n           replicas per host.\n\n    Raises:\n      RuntimeError: If this method is not be called from input_fn.\n    \"\"\"\n    if not self._call_from_input_fn:\n      raise RuntimeError('This TPUContext instance must not be called from'\n                         ' model_fn.')\n\n    if self._internal_ctx.is_input_sharded_per_core():\n      total_invocation_count = (\n          self._internal_ctx.num_hosts *\n          self._internal_ctx.num_of_replicas_per_host)\n      replicas_consumed = 1\n    elif self._internal_ctx.is_input_broadcast_with_iterators():\n      total_invocation_count = 1\n      replicas_consumed = self._internal_ctx.num_replicas\n    elif self._internal_ctx.is_replica_across_hosts():\n      total_invocation_count = self._internal_ctx.num_replicas\n      replicas_consumed = 1\n    else:\n      total_invocation_count = self._internal_ctx.num_hosts\n      replicas_consumed = self._internal_ctx.num_of_replicas_per_host\n    return (self._input_device, self._invocation_index, total_invocation_count,\n            replicas_consumed)\n\n  @property\n  def num_replicas(self):\n    \"\"\"The total number of replicas.\n\n    For non-model-parallelism, num_replicas should be the total num of TPU\n    cores in the system.\n\n    Returns:\n      The number of replicas.\n    \"\"\"\n    return self._internal_ctx.num_replicas\n\n  @property\n  def num_hosts(self):\n    
\"\"\"The number of hosts for the TPU system.\"\"\"\n    return self._internal_ctx.num_hosts\n\n  @property\n  def current_host(self):\n    \"\"\"The current host index for the TPU system.\n\n    Returns:\n      The host index (int).\n\n    Raises:\n      RuntimeError: If this method is not be called from input_fn.\n    \"\"\"\n\n    if not self._call_from_input_fn:\n      raise RuntimeError('This TPUContext instance must not be called from'\n                         ' model_fn.')\n\n    return self._host_id\n\n  @property\n  def num_of_replicas_per_host(self):\n    \"\"\"The number of replicas for each host.\"\"\"\n    if self._internal_ctx.model_parallelism_enabled:\n      raise ValueError(\n          'num_of_replicas_per_host is not supported for model_parallelism')\n    return self._internal_ctx.num_of_replicas_per_host\n\n  @property\n  def device_assignment(self):\n    \"\"\"Returns device_assignment object.\n\n    Raises:\n      RuntimeError: If this method is not be called from model_fn.\n    \"\"\"\n    if self._call_from_input_fn:\n      raise RuntimeError('This TPUContext instance must not be called from'\n                         ' input_fn.')\n    return self._internal_ctx.device_assignment\n\n  def device_for_replica(self, replica_id):\n    \"\"\"Returns the tuple of (CPU device and device ordinal) for replica.\n\n    This should be used for full replicate for non-model-parallelism.\n\n    Args:\n       replica_id: Int, the replica index.\n\n    Returns:\n       A tuple of device spec for CPU device and int device ordinal.\n    \"\"\"\n    # Note that: For the non-model parallelism, the mapping could be\n    # a random permutation. 
The order should not matter in most cases\n    # as far as model is replicated to all cores in the system.\n    return self._internal_ctx.device_for_replica(replica_id)\n\n  @property\n  def tpu_host_placement_function(self):\n    \"\"\"Returns the TPU host place function.\n\n    The place function takes host_id as the input and returns the TF device\n    for the correspoding host.\n    \"\"\"\n\n    def _placement_function(host_id):\n      \"\"\"Return the host device given host_id.\"\"\"\n      return self._internal_ctx.tpu_host_placement_function(host_id=host_id)\n\n    return _placement_function\n\n\nclass _InternalTPUContext(object):\n  \"\"\"A context holds immutable states of TPU computation.\n\n  This immutable object holds TPUEstimator config, train/eval batch size, and\n  `TPUEstimator.use_tpu`, which is expected to be passed around. It also\n  provides utility functions, based on the current state, to determine other\n  information commonly required by TPU computation, such as TPU device names,\n  TPU hosts, shard batch size, etc.\n\n  if eval_on_tpu is False, then execution of eval on TPU is disabled.\n  if eval_on_tpu is True, but use_tpu is False, a warning is issued,\n  and TPU execution is disabled for all modes.\n\n  N.B. 
As `mode` is not immutable state in Estimator, but essential to\n  distinguish between TPU training and evaluation, a common usage for\n  _InternalTPUContext with `mode` is as follows:\n  ```\n  with _ctx.with_mode(mode) as ctx:\n    if ctx.is_running_on_cpu():\n       ...\n  ```\n  \"\"\"\n\n  def __init__(self,\n               config,\n               train_batch_size,\n               eval_batch_size,\n               predict_batch_size,\n               use_tpu,\n               eval_on_tpu=True,\n               embedding_config_spec=None):\n    self._config = config\n    self._train_batch_size = train_batch_size\n    self._eval_batch_size = eval_batch_size\n    self._predict_batch_size = predict_batch_size\n    self._use_tpu = use_tpu\n    tf.compat.v1.logging.info('_TPUContext: eval_on_tpu %s', eval_on_tpu)\n    if not use_tpu and eval_on_tpu:\n      tf.compat.v1.logging.warn('eval_on_tpu ignored because use_tpu is False.')\n\n    self._eval_on_tpu = eval_on_tpu\n    self._model_parallelism_enabled = (\n        use_tpu and config.tpu_config.num_cores_per_replica)\n    self._mode = None\n    num_cores_per_replica = config.tpu_config.num_cores_per_replica\n    if self._model_parallelism_enabled:\n      self._computation_shape = _NUM_CORES_TO_COMPUTATION_SHAPE[\n          num_cores_per_replica]\n    else:\n      self._computation_shape = None\n    self._lazy_tpu_system_metadata_dict = {}  # key by master address\n    self._lazy_device_assignment_dict = {}  # key by master address\n    self._lazy_validation_dict = {}  # key by ModeKeys\n    self._embedding_config_spec = embedding_config_spec\n    self._lazy_embedding_config_dict = {}  # key by master address\n\n  def _assert_mode(self):\n    if self._mode is None:\n      raise RuntimeError(\n          '`mode` needs to be set via contextmanager `with_mode`.')\n    return self._mode\n\n  @contextmanager\n  def with_mode(self, mode):\n    # NOTE(xiejw): Shallow copy is enough. 
It will share he lazy dictionaries,\n    # such as _lazy_tpu_system_metadata_dict between new copy and the original\n    # one. Note that all lazy states stored in properties _lazy_foo are sort of\n    # immutable as they should be same for the process lifetime.\n    new_ctx = copy.copy(self)\n    new_ctx._mode = mode  # pylint: disable=protected-access\n    yield new_ctx\n\n  @property\n  def mode(self):\n    return self._assert_mode()\n\n  def _get_master_address(self):\n    mode = self._assert_mode()\n    config = self._config\n    master = (\n        config.master\n        if mode != model_fn_lib.ModeKeys.EVAL else config.evaluation_master)\n    return master\n\n  def _get_tpu_system_metadata(self):\n    \"\"\"Gets the (maybe cached) TPU system metadata.\"\"\"\n    master = self._get_master_address()\n    tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)\n    if tpu_system_metadata is not None:\n      return tpu_system_metadata\n\n    cluster_def = None\n    if (self._config.session_config and\n        self._config.session_config.cluster_def.job):\n      cluster_def = self._config.session_config.cluster_def\n\n    # pylint: disable=protected-access\n    tpu_system_metadata = (\n        tpu_system_metadata_lib._query_tpu_system_metadata(\n            master,\n            cluster_def=cluster_def,\n            query_topology=self.model_parallelism_enabled))\n\n    self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata\n    return tpu_system_metadata\n\n  def _get_device_assignment(self):\n    \"\"\"Gets the (maybe cached) TPU device assignment.\"\"\"\n    master = self._get_master_address()\n    device_assignment = self._lazy_device_assignment_dict.get(master)\n    if device_assignment is not None:\n      return device_assignment\n\n    tpu_system_metadata = self._get_tpu_system_metadata()\n\n    device_assignment = tpu_device_assignment.device_assignment(\n        tpu_system_metadata.topology,\n        
computation_shape=self._computation_shape,\n        num_replicas=self.num_replicas)\n\n    tf.compat.v1.logging.info(\n        'num_cores_per_replica: %s',\n        str(self._config.tpu_config.num_cores_per_replica))\n    tf.compat.v1.logging.info('computation_shape: %s',\n                              str(self._computation_shape))\n    tf.compat.v1.logging.info('num_replicas: %d', self.num_replicas)\n    tf.compat.v1.logging.info(\n        'device_assignment.topology.device_coordinates: %s',\n        str(device_assignment.topology.device_coordinates))\n    tf.compat.v1.logging.info('device_assignment.core_assignment: %s',\n                              str(device_assignment.core_assignment))\n\n    self._lazy_device_assignment_dict[master] = device_assignment\n    return device_assignment\n\n  @property\n  def tensor_core_embedding_columns(self):\n    if self._embedding_config_spec:\n      return self._embedding_config_spec.tensor_core_feature_columns\n    return None\n\n  @property\n  def embedding_config(self):\n    \"\"\"Returns the embedding config based on current mode.\"\"\"\n    master = self._get_master_address()\n    if master in self._lazy_embedding_config_dict:\n      embedding_config = self._lazy_embedding_config_dict[master]\n    else:\n      embedding_config = None\n      if self._use_tpu and self._embedding_config_spec:\n        embedding_config = _tpu_estimator_embedding.EmbeddingConfig(\n            self._embedding_config_spec, self._train_batch_size,\n            self._eval_batch_size, self.num_hosts, self.num_cores, self.config)\n        if not embedding_config.has_embedding_tables():\n          embedding_config = None\n      self._lazy_embedding_config_dict[master] = embedding_config\n\n    if embedding_config is not None:\n      mode = self._assert_mode()\n      # Dynamically attach tpu_embedding based on mode. 
With\n      # this, we could keep embedding_config immutable but call site always\n      # accesses the unified API '.tpu_embedding'.\n      embedding_config.tpu_embedding = embedding_config.get_tpu_embedding(mode)\n    return embedding_config\n\n  @property\n  def allow_per_host_v2_parallel_get_next(self):\n    return (self._config.tpu_config\n            .experimental_allow_per_host_v2_parallel_get_next)\n\n  @property\n  def feed_hook(self):\n    return (self._config.tpu_config.experimental_feed_hook)\n\n  @property\n  def model_parallelism_enabled(self):\n    return self._model_parallelism_enabled\n\n  @property\n  def input_partition_dims(self):\n    return self._config.tpu_config.input_partition_dims\n\n  @property\n  def device_assignment(self):\n    return (self._get_device_assignment()\n            if self._model_parallelism_enabled else None)\n\n  @property\n  def num_of_cores_per_host(self):\n    metadata = self._get_tpu_system_metadata()\n    return metadata.num_of_cores_per_host\n\n  @property\n  def num_cores(self):\n    metadata = self._get_tpu_system_metadata()\n    return metadata.num_cores\n\n  @property\n  def num_of_replicas_per_host(self):\n    \"\"\"Return the number of replicas per host.\"\"\"\n    if self.model_parallelism_enabled:\n      # There can be fewer replicas. This might return 0!\n      return self.num_replicas // self.num_hosts\n    else:\n      return self.num_of_cores_per_host\n\n  @property\n  def num_replicas(self):\n    \"\"\"Compute the total number of replicas.\"\"\"\n    num_cores_in_system = self.num_cores\n\n    if self.model_parallelism_enabled:\n      num_cores_per_replica = self._config.tpu_config.num_cores_per_replica\n      if num_cores_per_replica > num_cores_in_system:\n        raise ValueError(\n            'The num of cores required by the model parallelism, specified by '\n            'TPUConfig.num_cores_per_replica, is larger than the total num of '\n            'TPU cores in the system. 
num_cores_per_replica: {}, num cores '\n            'in the system: {}'.format(num_cores_per_replica,\n                                       num_cores_in_system))\n\n      if num_cores_in_system % num_cores_per_replica != 0:\n        raise RuntimeError(\n            'The num of cores in the system ({}) is not divisible by the num '\n            'of cores ({}) required by the model parallelism, specified by '\n            'TPUConfig.num_cores_per_replica. This should never happen!'.format(\n                num_cores_in_system, num_cores_per_replica))\n\n      return num_cores_in_system // num_cores_per_replica\n    else:\n      return num_cores_in_system\n\n  @property\n  def num_hosts(self):\n    metadata = self._get_tpu_system_metadata()\n    return metadata.num_hosts\n\n  @property\n  def config(self):\n    return self._config\n\n  def is_input_sharded_per_core(self):\n    \"\"\"Return true if input_fn is invoked per-core (other than per-host).\"\"\"\n    mode = self._assert_mode()\n    return (mode == model_fn_lib.ModeKeys.TRAIN and\n            (self._config.tpu_config.per_host_input_for_training is\n             tpu_config.InputPipelineConfig.PER_SHARD_V1))\n\n  def is_input_per_host_with_iterators(self):\n    \"\"\"Return true if input_fn should be run in the per-host v2 config.\"\"\"\n    return (self._config.tpu_config.per_host_input_for_training is\n            tpu_config.InputPipelineConfig.PER_HOST_V2)\n\n  def is_input_broadcast_with_iterators(self):\n    \"\"\"Return true if input_fn should be run in the full_replicae config.\"\"\"\n    return ((self._config.tpu_config.per_host_input_for_training is\n             tpu_config.InputPipelineConfig.BROADCAST) or\n            (self.is_input_slice_broadcast_to_all_cores()))\n\n  def is_input_slice_broadcast_to_all_cores(self):\n    \"\"\"Return true if input_fn is invoked once and broadcast to other hosts.\"\"\"\n    mode = self._assert_mode()\n    return (mode != model_fn_lib.ModeKeys.TRAIN and\n            
self._config.tpu_config.eval_training_input_configuration is\n            tpu_config.InputPipelineConfig.SLICED)\n\n  def is_replica_across_hosts(self):\n    \"\"\"Return true if single replica is across multiple hosts.\"\"\"\n    # For example, when num_cores_per_replica > num_cores_per_host.\n    num_cores_per_replica = self._config.tpu_config.num_cores_per_replica\n    num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host\n    return (num_cores_per_replica is not None and\n            num_cores_per_replica > num_cores_per_host)\n\n  def is_running_on_cpu(self, is_export_mode=False):\n    \"\"\"Determines whether the input_fn and model_fn should be invoked on CPU.\n\n    This API also validates user provided configuration, such as batch size,\n    according the lazy initialized TPU system metadata.\n\n    Args:\n      is_export_mode: Indicates whether the current mode is for exporting the\n        model, when mode == PREDICT. Only with this bool, we could tell whether\n        user is calling the Estimator.predict or Estimator.export_savedmodel,\n        which are running on TPU and CPU respectively. 
Parent class Estimator\n        does not distinguish these two.\n\n    Returns:\n      bool, whether current input_fn or model_fn should be running on CPU.\n\n    Raises:\n      ValueError: any configuration is invalid.\n    \"\"\"\n\n    is_running_on_cpu = self._is_running_on_cpu(is_export_mode)\n    if not is_running_on_cpu:\n      self._validate_tpu_configuration()\n    return is_running_on_cpu\n\n  def _is_running_on_cpu(self, is_export_mode):\n    \"\"\"Determines whether the input_fn and model_fn should be invoked on CPU.\"\"\"\n    mode = self._assert_mode()\n\n    if not self._use_tpu:\n      return True\n\n    if mode == model_fn_lib.ModeKeys.EVAL and not self._eval_on_tpu:\n      tf.compat.v1.logging.info('_is_running_on_cpu: eval_on_tpu disabled')\n      return True\n\n    if is_export_mode:\n      return True\n\n    return False\n\n  @property\n  def global_batch_size(self):\n    mode = self._assert_mode()\n    if mode == model_fn_lib.ModeKeys.TRAIN:\n      return self._train_batch_size\n    elif mode == model_fn_lib.ModeKeys.EVAL:\n      return self._eval_batch_size\n    elif mode == model_fn_lib.ModeKeys.PREDICT:\n      return self._predict_batch_size\n    else:\n      return None\n\n  @property\n  def batch_size_for_input_fn(self):\n    \"\"\"Returns the shard batch size for `input_fn`.\"\"\"\n    global_batch_size = self.global_batch_size\n    if (self.is_running_on_cpu() or self.is_input_broadcast_with_iterators()):\n      return global_batch_size\n\n    # On TPU\n    if self.is_input_sharded_per_core() or (\n        self.is_input_per_host_with_iterators()) or (\n            self.is_replica_across_hosts()):\n      return global_batch_size // self.num_replicas\n    else:\n      return global_batch_size // self.num_hosts\n\n  @property\n  def batch_size_for_model_fn(self):\n    \"\"\"Returns the shard batch size for `model_fn`.\"\"\"\n    global_batch_size = self.global_batch_size\n\n    if (self.is_running_on_cpu() or 
self.is_input_broadcast_with_iterators() and\n        not self.is_input_slice_broadcast_to_all_cores()):\n      return global_batch_size\n\n    # On TPU. always sharded per shard.\n    return global_batch_size // self.num_replicas\n\n  @property\n  def master_job(self):\n    \"\"\"Returns the job name to use to place TPU computations on.\n\n    Returns:\n      A string containing the job name, or None if no job should be specified.\n\n    Raises:\n      ValueError: If the user needs to specify a tpu_job_name, because we are\n        unable to infer the job name automatically, or if the user-specified job\n        names are inappropriate.\n    \"\"\"\n    run_config = self._config\n    # If the user specifies the tpu_job_name, use that.\n    if run_config.tpu_config.tpu_job_name:\n      return run_config.tpu_config.tpu_job_name\n\n    # The tpu job is determined by the run_config. Right now, this method is\n    # required as tpu_config is not part of the RunConfig.\n    mode = self._assert_mode()\n    master = (\n        run_config.evaluation_master\n        if mode == model_fn_lib.ModeKeys.EVAL else run_config.master)\n    cluster_def = (\n        run_config.session_config.cluster_def\n        if run_config.session_config else None)\n\n    try:\n      master_job = tpu_system_metadata_lib.master_job(master, cluster_def)\n    except ValueError as e:\n      raise ValueError(\n          str(e) + ' Please specify a tpu_job_name as part of '\n          'your TPUConfig.')\n    return master_job\n\n  @property\n  def tpu_host_placement_function(self):\n    \"\"\"Returns the TPU host place function.\"\"\"\n\n    master = self.master_job\n\n    def _placement_function(_sentinal=None, replica_id=None, host_id=None):  # pylint: disable=invalid-name\n      \"\"\"Return the host device given replica_id or host_id.\"\"\"\n      assert _sentinal is None\n      if replica_id is not None and host_id is not None:\n        raise RuntimeError(\n            'replica_id and host_id can 
have only one non-None value.')\n\n      if master is None:\n        return '/replica:0/task:0/device:CPU:0'\n      else:\n        if replica_id is not None:\n          if self.model_parallelism_enabled:\n            return self.device_assignment.host_device(\n                replica=replica_id, job=master)\n          else:\n            host_id = replica_id / self.num_of_cores_per_host\n\n        return '/job:%s/task:%d/device:CPU:0' % (master, host_id)\n\n    return _placement_function\n\n  @property\n  def tpu_device_placement_function(self):\n    \"\"\"Returns a TPU device placement Fn.\"\"\"\n    master = self.master_job\n    job_device = '' if master is None else ('/job:%s' % master)\n\n    def _placement_function(i):\n      if self.model_parallelism_enabled:\n        return self.device_assignment.tpu_device(replica=i, job=master)\n      else:\n        num_of_cores_per_host = self.num_of_cores_per_host\n        host_id = i / num_of_cores_per_host\n        ordinal_id = i % num_of_cores_per_host\n        return '%s/task:%d/device:TPU:%d' % (job_device, host_id, ordinal_id)\n\n    return _placement_function\n\n  def tpu_ordinal_function(self, host_id):\n    \"\"\"Returns the TPU ordinal fn.\"\"\"\n\n    def _tpu_ordinal_function(shard_index_in_host):\n      \"\"\"Return the TPU ordinal associated with a shard.\n\n      Required because the enqueue ops are placed on CPU.\n\n      Args:\n        shard_index_in_host: the shard index\n\n      Returns:\n        The ordinal of the TPU device the shard's infeed should be placed on.\n      \"\"\"\n      if self.model_parallelism_enabled:\n        # We put both enqueue/dequeue ops at tpu.core(0) in each replica.\n        replica = self.device_assignment.lookup_replicas(host_id,\n                                                         0)[shard_index_in_host]\n        return self.device_assignment.tpu_ordinal(replica=replica)\n      else:\n        return shard_index_in_host % self.num_of_cores_per_host\n\n    return 
_tpu_ordinal_function\n\n  def _validate_tpu_configuration(self):\n    \"\"\"Validates the configuration based on the TPU system metadata.\"\"\"\n    mode = self._assert_mode()\n    if self._lazy_validation_dict.get(mode):\n      return\n\n    # All following information is obtained from TPU system metadata.\n    num_cores = self.num_cores\n    num_replicas = self.num_replicas\n    num_hosts = self.num_hosts\n\n    if not num_cores:\n      tpu_system_metadata = self._get_tpu_system_metadata()\n      raise RuntimeError(\n          'Cannot find any TPU cores in the system. Please double check '\n          'Tensorflow master address and TPU worker(s). Available devices '\n          'are {}.'.format(tpu_system_metadata.devices))\n\n    if self._config.tpu_config.num_shards:\n      user_provided_num_replicas = self._config.tpu_config.num_shards\n      if user_provided_num_replicas != num_replicas:\n        message = (\n            'TPUConfig.num_shards is not set correctly. According to TPU '\n            'system metadata for Tensorflow master ({}): num_replicas should '\n            'be ({}), got ({}). For non-model-parallelism, num_replicas should '\n            'be the total num of TPU cores in the system. For '\n            'model-parallelism, the total number of TPU cores should be '\n            'num_cores_per_replica * num_replicas. 
Please set it '\n            'accordingly or leave it as `None`'.format(\n                self._get_master_address(), num_replicas,\n                user_provided_num_replicas))\n\n        raise ValueError(message)\n\n    if self._config.tpu_config.num_cores_per_replica and (\n        not self.is_input_per_host_with_iterators()):\n      num_cores_per_replica = self._config.tpu_config.num_cores_per_replica\n      num_cores_per_host = self._get_tpu_system_metadata().num_of_cores_per_host\n      if num_cores_per_replica > num_cores_per_host:\n        raise ValueError(\n            'Except the PER_HOST_V2 mode, the num of cores required by '\n            'model parallelism specified by TPUConfig.num_cores_per_replica '\n            'should be less than or equal to the num_cores_per_host. '\n            'num_cores_per_replica: {}, num_cores_per_host: {}'.format(\n                num_cores_per_replica, num_cores_per_host))\n\n    if mode == model_fn_lib.ModeKeys.TRAIN:\n      if (self._train_batch_size % num_replicas != 0 and\n          not self.is_input_broadcast_with_iterators()):\n        raise ValueError(\n            'train batch size {} must be divisible by number of replicas {}'\n            .format(self._train_batch_size, num_replicas))\n\n    elif mode == model_fn_lib.ModeKeys.EVAL:\n      if self._eval_batch_size is None:\n        raise ValueError(\n            'eval_batch_size in TPUEstimator constructor cannot be `None` '\n            'if .evaluate is running on TPU.')\n      if (self._eval_batch_size % num_replicas != 0 and\n          not self.is_input_broadcast_with_iterators()):\n        raise ValueError(\n            'eval batch size {} must be divisible by number of replicas {}'\n            .format(self._eval_batch_size, num_replicas))\n      if (num_hosts != 1 and\n          not self.is_input_broadcast_with_iterators() and\n          not self.is_input_per_host_with_iterators()):\n        raise ValueError(\n            'TPUEstimator.evaluate is only 
supported under three conditions: '\n            '1. num_hosts=1; 2. BROADCAST mode; '\n            '3. PER_HOST_V2 mode. '\n            'mode: {}; num_hosts: {}; num_replicas=1:{}'.format(\n                self._config.tpu_config.per_host_input_for_training, num_hosts,\n                num_replicas))\n      if num_hosts > 1 and self.is_input_per_host_with_iterators():\n        tf.compat.v1.logging.warn('Running TPUEstimator.evaluate for input mode'\n                                  ' PER_HOST_V2 and num_hosts %d', num_hosts)\n    else:\n      assert mode == model_fn_lib.ModeKeys.PREDICT\n      if self._predict_batch_size is None:\n        raise ValueError(\n            'predict_batch_size in TPUEstimator constructor cannot be `None` '\n            'if .predict is running on TPU.')\n      if (self._predict_batch_size % num_replicas != 0 and\n          not self.is_input_broadcast_with_iterators()):\n        raise ValueError(\n            'predict batch size {} must be divisible by number of replicas {}'\n            .format(self._predict_batch_size, num_replicas))\n      if num_hosts != 1 and not (\n          self.is_input_broadcast_with_iterators()) and not (\n              num_replicas == 1 and self.is_input_per_host_with_iterators()):\n        raise ValueError(\n            'TPUEstimator.predict is only supported under three conditions: '\n            '1. num_hosts=1; 2. BROADCAST mode; '\n            '3. PER_HOST_V2 mode with num_replicas=1. 
'\n            'mode: {}; num_hosts: {}; num_replicas=1:{}'.format(\n                self._config.tpu_config.per_host_input_for_training, num_hosts,\n                num_replicas))\n\n    # Record the state \"validated\" into lazy dictionary.\n    self._lazy_validation_dict[mode] = True\n\n  def device_for_replica(self, replica_id):\n    \"\"\"Returns the tuple of (CPU device and device ordinal) for replica.\n\n    This should be used for full replicate for non-model-parallelism.\n\n    Args:\n       replica_id: Int, the replica index.\n\n    Returns:\n       A tuple of device spec for CPU device and int device ordinal.\n    \"\"\"\n    master = self.master_job\n\n    if self.model_parallelism_enabled:\n      return (self.device_assignment.host_device(\n          replica=replica_id,\n          job=master), self.device_assignment.tpu_ordinal(replica=replica_id))\n\n    job_device = '' if master is None else ('/job:%s' % master)\n\n    num_of_replicas_per_host = self.num_of_replicas_per_host\n    assert num_of_replicas_per_host > 0, (\n        'Got num_of_replicas_per_host: {}'.format(num_of_replicas_per_host))\n    host_id = replica_id / num_of_replicas_per_host\n    ordinal_id = replica_id % num_of_replicas_per_host\n\n    host_device = '%s/task:%d/device:CPU:0' % (job_device, host_id)\n    return (host_device, ordinal_id)\n\n\nclass _OneCoreTPUContext(_InternalTPUContext):\n  \"\"\"Special _InternalTPUContext for one core usage.\"\"\"\n\n  def __init__(self, config, train_batch_size, eval_batch_size,\n               predict_batch_size, use_tpu):\n\n    super(_OneCoreTPUContext,\n          self).__init__(config, train_batch_size, eval_batch_size,\n                         predict_batch_size, use_tpu)\n\n  def _get_tpu_system_metadata(self):\n    \"\"\"Gets the (maybe cached) TPU system metadata.\"\"\"\n    master = self._get_master_address()\n    tpu_system_metadata = self._lazy_tpu_system_metadata_dict.get(master)\n    if tpu_system_metadata is not None:\n      
return tpu_system_metadata\n\n    tpu_system_metadata = (\n        tf.tpu.experimental.TPUSystemMetadata(  # pylint: disable=protected-access\n            num_cores=1,\n            num_hosts=1,\n            num_of_cores_per_host=1,\n            topology=None,\n            devices=[]))\n\n    self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata\n    return tpu_system_metadata\n\n\nclass _TPUEstimatorReplicaContext(tf.distribute.ReplicaContext):\n  \"\"\"Internal context for storing replica id.\n\n  This is to set eager.context.Context() so that only summary ops from\n  0th replica is executed.\n  \"\"\"\n\n  def __init__(self, replica_id_in_sync):\n    \"\"\"Creates internal replica context for TPUEstimator.\n\n    Args:\n      replica_id_in_sync: Zero indexed integer id of replica that is running the\n        TPU compuation.\n    \"\"\"\n    super(_TPUEstimatorReplicaContext, self).__init__(None, replica_id_in_sync)\n    # Use default strategy and replica context when variables are\n    # accessed/watched for backpropagation.\n    # pylint: disable=protected-access\n    self._thread_context = distribute_lib._DefaultReplicaThreadMode(\n    )\n    self._strategy = self._thread_context.strategy\n    # pylint: enable=protected-access\n\n  def __enter__(self):\n\n    def replica_id_is_zero():\n      return tf.math.equal(self.replica_id_in_sync_group, tf.constant(0))\n\n    if hasattr(summary_ops_v2, '_summary_state'):\n      summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access\n      self._summary_recording_distribution_strategy = (\n          summary_state.is_recording_distribution_strategy)\n      summary_state.is_recording_distribution_strategy = replica_id_is_zero\n\n  def __exit__(self, exception_type, exception_value, traceback):\n    if hasattr(summary_ops_v2, '_summary_state'):\n      summary_state = summary_ops_v2._summary_state  # pylint: disable=protected-access\n      summary_state.is_recording_distribution_strategy 
= (\n          self._summary_recording_distribution_strategy)\n\n\ndef _get_tpu_context(config, train_batch_size, eval_batch_size,\n                     predict_batch_size, use_tpu, eval_on_tpu,\n                     embedding_config_spec):\n  \"\"\"Returns an instance of `_InternalTPUContext`.\"\"\"\n\n  if (config.tpu_config.num_shards == 1 and\n      config.tpu_config.num_cores_per_replica is None):\n    if embedding_config_spec is not None:\n      raise ValueError('Setting TPUConfig.num_shards==1 is unsupported '\n                       'when embedding_config_spec is not None.')\n    tf.compat.v1.logging.warn(\n        'Setting TPUConfig.num_shards==1 is an unsupported behavior. '\n        'Please fix as soon as possible (leaving num_shards as None.)')\n    return _OneCoreTPUContext(config, train_batch_size, eval_batch_size,\n                              predict_batch_size, use_tpu)\n\n  return _InternalTPUContext(config, train_batch_size, eval_batch_size,\n                             predict_batch_size, use_tpu, eval_on_tpu,\n                             embedding_config_spec)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/tpu_enqueue_sequence_test.py",
    "content": "# Copyright 2021 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for sequence embedding features using TPU and TPUEstimator.\"\"\"\n\nimport os\nfrom typing import Dict, List, Text, Tuple\nfrom absl import flags\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow.compat.v1 as tf\n\nfrom tensorflow.contrib import summary as contrib_summary\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v1\nfrom tensorflow_estimator.python.estimator.tpu import tpu_config as tpu_config_lib\nfrom tensorflow_estimator.python.estimator.tpu import tpu_estimator\n\nFLAGS = flags.FLAGS\n\nclass TPUEnqueueSequenceTest(tf.test.TestCase, parameterized.TestCase):\n\n  def setUp(self):\n    super().setUp()\n    temp_dir = self.get_temp_dir()\n    self._model_dir = os.path.join(temp_dir, 'model_dir')\n    self._summary_dir = os.path.join(temp_dir, 'summaries')\n\n    os.mkdir(self._model_dir)\n    os.mkdir(self._summary_dir)\n\n  # The key in the dataset which holds the sparse IDs. 
TPUEstimator will pass\n  # the embeddings in the features dictionary arg of model_fn after performing\n  # the embedding lookups.\n  _KEY = 'SparseIDs'\n\n  # The names of the summaries which hold the activations/sequence lengths.\n  _SUMMARY_ACTIVATIONS = 'summary_activations'\n  _SUMMARY_SEQUENCE_LENGTHS = 'summary_sequence_lengths'\n\n  def get_activations_and_sequence_lengths(\n      self,\n      embedding_weights: List[List[float]],\n      sparse_ids: tf.SparseTensorValue,\n      batch_size: int,\n      max_sequence_length: int,\n      dimension: int,\n      combiner: Text = 'mean',\n  ) -> Tuple[tf.Tensor, tf.Tensor]:\n    \"\"\"Gets the activations and seq lengths for a batch of sparse IDs.\n\n    This method uses TPUEstimator and the Feature Column API to get embedding\n    activations for a batch of sparse of sparse IDs using a specified set of\n    embedding weights.\n\n    Args:\n      embedding_weights: The embedding weights as a 2D list of floats.  The\n        outer list length is the vocabulary size of the embedding table.  The\n        inner list length is the dimension of the embedding weights.\n      sparse_ids: The embedding IDs to lookup. This is a 2D SparseTensorValue of\n        shape [batch_size, max_sequence_length].\n      batch_size: The size of the first dimension of sparse_ids.\n      max_sequence_length:  The size of the second dimension of sparse_ids.\n      dimension: The embedding dimension size (number of floats for each\n        embedding ID).\n      combiner: The embedding column combiner (used for multivalent features).\n\n    Returns:\n      A tuple containing:\n        activations:  The activations for the specified sparse_ids.\n          type=float32, shape=[batch_size, max_sequence_length, dimension]\n        sequence_lengths: The sequence length of each example.\n          type=int64. 
shape=[batch_size].\n    \"\"\"\n\n    vocab_size = len(embedding_weights)\n    categorical_column = (\n        tf.feature_column.sequence_categorical_column_with_identity(\n            key=self._KEY,\n            num_buckets=vocab_size,\n        ))\n\n    # Create embedding column initialized with weights provided by caller.\n    embedding_column = tf.tpu.experimental.embedding_column(\n        categorical_column,\n        dimension=dimension,\n        max_sequence_length=max_sequence_length,\n        initializer=tf.constant_initializer(embedding_weights),\n        combiner=combiner,\n    )\n\n    # Add an SGD optimizer. This choice is arbitrary for computing activations.\n    # It's only required to avoid an undefined gradients error.\n    embedding_opt = tf.tpu.experimental.StochasticGradientDescentParameters(.1)\n    embedding_config_spec = tpu_estimator.EmbeddingConfigSpec(\n        feature_columns=[embedding_column],\n        optimization_parameters=embedding_opt,\n    )\n\n    def _input_fn(params: Dict[Text, int]) -> tf.data.Dataset:\n      \"\"\"Creates a batched dataset containing the sparse_ids as a feature.\"\"\"\n      # Convert sparse IDs to batched dataset.\n      sparse_ids_dataset = tf.data.Dataset.range(1).map(\n          lambda x: {self._KEY: tf.SparseTensor.from_value(sparse_ids)})\n\n      # Unbatch and rebatch the dataset based on the batch_size param from\n      # TPUEstimator. 
This is necessary for shape validation performed internal\n      # to TPUEstimator.\n      return sparse_ids_dataset.unbatch().repeat().batch(params['batch_size'])\n\n    def _host_call(\n        concat_activations: tf.Tensor,\n        concat_sequence_lengths: tf.Tensor,\n    ) -> List[tf.Operation]:\n      \"\"\"Stores the activations and sequence lengths into a summary.\n\n      TPUEstimator will concat the activations and sequence lengths from the\n      minibatches on each core along axis=0 and pass them to this host call.\n      This host call writes them to a file using the TF summary APIs.\n\n      Args:\n        concat_activations: The activations for the global batch. 2D\n          Tensor(type=float32, shape=[batch_size, max_sequence_length]).\n        concat_sequence_lengths:  The sequence lengths for the global batch. 2D\n          Tensor(type=int64, shape=[batch_size, max_sequence_length]).\n\n      Returns:\n        A list of summary ops for TPUEstimator to run on the host.\n      \"\"\"\n      with contrib_summary.create_file_writer(self._summary_dir).as_default():\n        with contrib_summary.always_record_summaries():\n          contrib_summary.generic(\n              self._SUMMARY_ACTIVATIONS,\n              concat_activations,\n          )\n          contrib_summary.generic(self._SUMMARY_SEQUENCE_LENGTHS,\n                                  concat_sequence_lengths)\n          return contrib_summary.all_summary_ops()\n\n    def _model_fn(\n        features: Dict[Text, tf.Tensor],\n        params: Dict[Text, int],\n        mode: model_fn_lib.ModeKeys,\n    ) -> tpu_estimator.TPUEstimatorSpec:\n      \"\"\"A model which writes activations and sequence lengths to a file.\n\n      This method creates a model to extract the activations and sequence\n      lengths on each TPU core and pass them to a host call which writes them\n      to a file.\n\n      The model also applies an optimizer to the activations simply to avoid an\n      undefined gradients 
error.\n\n      Args:\n        features: A dictionary mapping keys to tensor inputs.\n        params: Parameters passed by TPUEstimator.\n        mode: Mode can be (TRAIN, EVAL, PREDICT).\n\n      Returns:\n        A TPUEstimatorSpec which holds the training_op that TPUEstimator will\n        run on TPU and the host_call that TPUEstimator will run on the host.\n      \"\"\"\n      del params\n      input_layer = tf_keras_v1.experimental.SequenceFeatures([embedding_column])\n      activations, sequence_lengths = input_layer(features)\n      opt = tf.tpu.CrossShardOptimizer(tf.train.GradientDescentOptimizer(0.1))\n      loss = tf.reduce_sum(activations)\n      train_op = opt.minimize(loss, global_step=tf.train.get_global_step())\n\n      return tpu_estimator.TPUEstimatorSpec(\n          mode=mode,\n          loss=loss,\n          train_op=train_op,\n          host_call=(_host_call, [activations, sequence_lengths]),\n      )\n\n    tpu_config = tpu_config_lib.TPUConfig(\n        per_host_input_for_training=(\n            tpu_config_lib.InputPipelineConfig.PER_HOST_V2),)\n    run_config = tpu_config_lib.RunConfig(\n        session_config=tf.ConfigProto(isolate_session_state=True),\n        tpu_config=tpu_config,\n    )\n    estimator = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        model_dir=self._model_dir,\n        use_tpu=True,\n        train_batch_size=batch_size,\n        eval_batch_size=batch_size,\n        config=run_config,\n        embedding_config_spec=embedding_config_spec,\n    )\n\n    # Train for 1 step and store the activations as summaries.\n    estimator.train(_input_fn, steps=1)\n\n    # Read the event summaries and decode the activation tensors.\n    output = {}\n    for filename in tf.io.gfile.listdir(self._summary_dir):\n      filepath = os.path.join(os.path.join(self._summary_dir, filename))\n      for event in tf.train.summary_iterator(filepath):\n        for v in event.summary.value:\n          decoded = 
tf.io.decode_raw(v.tensor.tensor_content, v.tensor.dtype)\n          shape = tf.TensorShape(v.tensor.tensor_shape)\n          output[v.tag] = tf.reshape(decoded, shape)\n    return (output[self._SUMMARY_ACTIVATIONS],\n            output[self._SUMMARY_SEQUENCE_LENGTHS])\n\n  def test_non_contiguous_sequence(self):\n    \"\"\"Tests embedding lookups for non-contiguous sparse IDs.\n\n    A \"non-contiguous sequence\" is a sequence which has missing values followed\n    by actual values.\n    \"\"\"\n    batch_size = 4\n    max_sequence_length = 3\n    dimension = 2\n    embedding_weights = np.float32([\n        [-5., -5.],  # embedding ID = 0\n        [10., 11.],  # embedding ID = 1\n        [20., 21.],  # embedding ID = 2\n        [30., 31.],  # embedding ID = 3\n        [40., 41.],  # embedding ID = 4\n        [50., 51.],  # embedding ID = 5\n    ])\n\n    # The sparse_ids are indexes into the embedding_weights for each\n    # (example, sequence_index).\n    sparse_ids = tf.SparseTensorValue(\n        indices=[[0, 0], [1, 0], [1, 1], [2, 0], [2, 2]],\n        values=[\n            1,  # Example 0, sequence_index 0\n            2,  # Example 1, sequence_index 0\n            3,  # Example 1, sequence_index 1\n            4,  # Example 2, sequence_index 0\n            5,  # Example 2, sequence_index 2\n        ],\n        dense_shape=[batch_size, max_sequence_length],\n    )\n\n    activations, sequence_lengths = self.get_activations_and_sequence_lengths(\n        embedding_weights,\n        sparse_ids,\n        batch_size,\n        max_sequence_length,\n        dimension,\n    )\n\n    self.assertAllEqual(\n        [\n            [  # Example 0\n                [10, 11],  # Sequence Index = 0\n                [0., 0.],  # Sequence Index = 1\n                [0., 0.],  # Sequence Index = 2\n            ],\n            [  # Example 1\n                [20, 21],  # Sequence Index = 0\n                [30, 31],  # Sequence Index = 1\n                [0., 0.],  # Sequence 
Index = 2\n            ],\n            [  # Example 2\n                [40, 41],  # Sequence Index = 0\n                [0., 0.],  # Sequence Index = 1 (Missing value mid-sequence)\n                [50, 51],  # Sequence Index = 2\n            ],\n            [  # Example 3\n                [0., 0.],  # Sequence Index = 0\n                [0., 0.],  # Sequence Index = 1\n                [0., 0.],  # Sequence Index = 2\n            ],\n        ],\n        activations)\n    self.assertAllEqual(\n        [\n            1,  # Example 0\n            2,  # Example 1\n            3,  # Example 2\n            0,  # Example 3\n        ],\n        sequence_lengths,\n    )\n\n  def test_non_contiguous_sequence_with_length_gt_max_sequence_length(self):\n    \"\"\"Tests non contiguous sequence which has length > max_sequence_length.\n\n    A \"non-contiguous sequence\" is a sequence which has missing values followed\n    by actual values.\n\n    Additionally, this test has a sequence with length > max_sequence_length. In\n    this case, we expect the sequence to be truncated from the right.\n    \"\"\"\n    batch_size = 4\n    max_sequence_length = 3\n    dimension = 1\n    embedding_weights = np.float32([\n        [-5.],  # embedding ID = 0\n        [10.],  # embedding ID = 1\n        [20.],  # embedding ID = 2\n        [30.],  # embedding ID = 3\n        [40.],  # embedding ID = 4\n        [50.],  # embedding ID = 5\n    ])\n\n    # The sparse_ids are indexes into the embedding_weights for each\n    # (example, sequence_index).  
Sequence indexes larger than max_sequence\n    # length will be truncated.\n    sparse_ids = tf.SparseTensorValue(\n        indices=[[0, 0], [1, 0], [1, 1], [2, 0], [2, 2], [2, 3]],\n        values=[\n            1,  # Example 0, sequence_index 0\n            2,  # Example 1, sequence_index 0\n            3,  # Example 1, sequence_index 1\n            4,  # Example 2, sequence_index 0\n            5,  # Example 2, sequence_index 2\n            6,  # Example 2, sequence_index 3\n        ],\n        dense_shape=[batch_size, max_sequence_length + 1],\n    )\n\n    activations, sequence_lengths = self.get_activations_and_sequence_lengths(\n        embedding_weights,\n        sparse_ids,\n        batch_size,\n        max_sequence_length,\n        dimension,\n    )\n\n    self.assertAllEqual(\n        [\n            [  # Example 0\n                [10],  # Sequence Index = 0\n                [0.],  # Sequence Index = 1\n                [0.],  # Sequence Index = 2\n            ],\n            [  # Example 1\n                [20],  # Sequence Index = 0\n                [30],  # Sequence Index = 1\n                [0.],  # Sequence Index = 2\n            ],\n            [  # Example 2 (Truncated)\n                [40],  # Sequence Index = 0\n                [0.],  # Sequence Index = 1 (Missing value mid-sequence)\n                [50],  # Sequence Index = 2\n            ],\n            [  # Example 3\n                [0.],  # Sequence Index = 0\n                [0.],  # Sequence Index = 1\n                [0.],  # Sequence Index = 2\n            ],\n        ],\n        activations)\n\n    self.assertAllEqual(\n        [\n            1,  # Example 0\n            2,  # Example 1\n            3,  # Example 2\n            0,  # Example 3\n        ],\n        sequence_lengths,\n    )\n\n  @parameterized.named_parameters(\n      ('sum_combiner', 'sum'),\n      ('mean_combiner', 'mean'),\n  )\n  def test_multivalent_sequence_features(self, combiner: Text):\n    \"\"\"Tests 
multivalent sequence embedding features.\n\n    Args:\n      combiner: The combiner used to reduce multivalent features.  A multivalent\n        sequence can have many IDs per sequence index.  The input for\n        multivalent sequence features is a 3D SparseTensor (instead of a 2D\n        SparseTensor for univalent sequence features).  The last dimension\n        represents the index that will be reduced (using the combiner).\n    \"\"\"\n    batch_size = 4\n    max_sequence_length = 3\n    dimension = 1\n    embedding_weights = np.float32([\n        [-5.],  # embedding ID = 0\n        [10.],  # embedding ID = 1\n        [20.],  # embedding ID = 2\n        [30.],  # embedding ID = 3\n        [40.],  # embedding ID = 4\n        [50.],  # embedding ID = 5\n    ])\n\n    # For multivalent sequence features, IDs are a 3D sparse tensor.\n    # The outer dimension is batch, the middle dimension is sequence, and the\n    # last dimension is the index.\n    sparse_ids = tf.SparseTensorValue(\n        indices=[\n            [0, 0, 0],\n            [0, 0, 1],\n            [1, 0, 0],\n            [1, 1, 0],\n            [3, 0, 0],\n            [3, 2, 0],\n            [3, 2, 1],\n            [3, 3, 0],\n        ],\n        values=[\n            1,  # Example 0, sequence_index 0,  id_index 0.\n            0,  # Example 0, sequence_index 0,  id_index 1.\n            2,  # Example 1, sequence_index 0,  id_index 0.\n            3,  # Example 1, sequence_index 1,  id_index 0.\n            4,  # Example 3, sequence_index 0,  id_index 0.\n            5,  # Example 3, sequence_index 2.  id_index 0.\n            2,  # Example 3, sequence_index 2.  
id_index 1.\n            5,  # Example 3, sequence_index 3,  id_index 0.\n        ],\n        dense_shape=[batch_size, max_sequence_length + 1, 2],\n    )\n\n    activations, sequence_lengths = self.get_activations_and_sequence_lengths(\n        embedding_weights,\n        sparse_ids,\n        batch_size,\n        max_sequence_length,\n        dimension,\n        combiner=combiner,\n    )\n\n    self.assertAllEqual(\n        [\n            [  # Example 0\n                [5 if combiner == 'sum' else 2.5],  # Sequence Index = 0.\n                [0.],  # Sequence Index = 1.\n                [0.],  # Sequence Index = 2.\n            ],\n            [  # Example 1\n                [20],  # Sequence Index = 0.\n                [30],  # Sequence Index = 1.\n                [0.],  # Sequence Index = 2.\n            ],\n            [  # Example 2\n                [0.],  # Sequence Index = 0.\n                [0.],  # Sequence Index = 1.\n                [0.],  # Sequence Index = 2.\n            ],\n            [  # Example 3\n                [40],  # Sequence Index = 0.\n                [0.],  # Sequence Index = 1.\n                [70 if combiner == 'sum' else 35],  # Sequence Index = 2.\n            ],\n        ],\n        activations,\n    )\n\n    self.assertAllEqual(\n        [\n            1,  # Example 0\n            2,  # Example 1\n            0,  # Example 2\n            3,  # Example 3\n        ],\n        sequence_lengths,\n    )\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/tpu_estimator.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ===================================================================\n\"\"\"TPUEstimator class.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport copy\nimport enum\nimport math\nimport os\nimport signal\nimport sys\nimport threading\nimport time\n\nimport tensorflow as tf\nimport numpy as np\nimport six\nfrom six.moves import queue as Queue  # pylint: disable=redefined-builtin\nfrom six.moves import xrange  # pylint: disable=redefined-builtin\n\nfrom tensorflow.core.framework import variable_pb2\nfrom tensorflow.core.framework.summary_pb2 import Summary\nfrom tensorflow.core.protobuf.tpu import compilation_result_pb2 as tpu_compilation_result\nfrom tensorflow.python.data.util import nest as data_nest\nfrom tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver\nfrom tensorflow.python.framework import function\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.ops import control_flow_ops\nfrom tensorflow.python.ops import control_flow_util\nfrom tensorflow.python.ops import ref_variable\nfrom tensorflow.python.ops import summary_ops_v2\nfrom tensorflow.python.ops import variable_scope\nfrom tensorflow.python.platform import tf_logging as logging\nfrom tensorflow.python.tpu import functional as 
tpu_functional\nfrom tensorflow.python.tpu import preempted_hook\nfrom tensorflow.python.tpu import session_support\nfrom tensorflow.python.tpu import tensor_tracer\nfrom tensorflow.python.tpu import tpu\nfrom tensorflow.python.tpu import tpu_embedding_gradient\nfrom tensorflow.python.tpu import tpu_feed\nfrom tensorflow.python.tpu import tpu_function\nfrom tensorflow.python.tpu import tpu_replication\nfrom tensorflow.python.tpu import training_loop\nfrom tensorflow.python.tpu.ops import tpu_ops\nfrom tensorflow.python.training import evaluation\nfrom tensorflow.python.util import function_utils\nfrom tensorflow.python.util import tf_inspect\nfrom tensorflow_estimator.python.estimator import estimator as estimator_lib\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\nfrom tensorflow_estimator.python.estimator.export import export_output as export_output_lib\nfrom tensorflow_estimator.python.estimator.tpu import _tpu_estimator_embedding\nfrom tensorflow_estimator.python.estimator.tpu import error_handling\nfrom tensorflow_estimator.python.estimator.tpu import iteration_count_estimator\nfrom tensorflow_estimator.python.estimator.tpu import tpu_config\nfrom tensorflow_estimator.python.estimator.tpu import tpu_context\nfrom tensorflow_estimator.python.estimator.tpu import util as util_lib\nfrom tensorflow_estimator.python.estimator.tpu._tpu_estimator_embedding import AdagradParameters  # pylint: disable=unused-import\nfrom tensorflow_estimator.python.estimator.tpu._tpu_estimator_embedding import AdamParameters  # pylint: disable=unused-import\nfrom tensorflow_estimator.python.estimator.tpu._tpu_estimator_embedding import EmbeddingConfigSpec  # pylint: disable=unused-import\nfrom tensorflow_estimator.python.estimator.tpu._tpu_estimator_embedding import StochasticGradientDescentParameters  # pylint: disable=unused-import\n\n_INITIAL_LOSS = 1e7\n_ZERO_LOSS = 
0.\n_TPU_ESTIMATOR = 'tpu_estimator'\n_ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop'\n_BATCH_SIZE_KEY = 'batch_size'\n_CTX_KEY = 'context'\n_USE_TPU_KEY = 'use_tpu'\n_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum'\n_ONE_GIGABYTE = 1024 * 1024 * 1024\n_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops'\n_TPU_TRAIN_OP = '_tpu_train_op'\n_INFERENCE_ON_TPU_MODE = '_inference_on_tpu'\n_KEY_WHEN_PREDICTIONS_IS_A_TENSOR = '_key_when_predictions_is_a_tensor'\n_TENSOR_PACKER_SMALL_FEATURE_DIM_SIZE = 1\n_TENSOR_PACKER_MINIMUM_NUM_SMALL_FEATURES_TO_GROUP = 5\n_TENSOR_PACKER_CONCATENATED_SMALL_FEATURES_KEY = '_concatenated_small_features'\n\n# Ideally _USE_TPU_KEY should be reserved as well. However there are already\n# models that make use of this key, thus it can not be reserved now to prevent\n# breakage. In the long run, we would like to mitigate this by migrating models\n# off of using _USE_TPU_KEY.\n_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY]\n\n# TODO(b/65703635): Flip the value and remove all dead code. Currently, this is\n# only used for per-core based deployments. 
For per-host based pipelines, if a\n# user returns a Dataset instance it will be automatically wrapped in a\n# tf.while_loop (This can be disabled by returning features and labels\n# explicitly).\n_WRAP_INPUT_FN_INTO_WHILE_LOOP = False\n\n# Track the adoption of TPUEstimator\n_tpu_estimator_gauge = tf.compat.v2.__internal__.monitoring.BoolGauge(\n    '/tensorflow/api/tpu_estimator',\n    'Whether the program uses tpu estimator or not.')\n\nif ops.get_to_proto_function('{}_{}'.format(_TPU_ESTIMATOR,\n                                            _ITERATIONS_PER_LOOP_VAR)) is None:\n  ops.register_proto_function(\n      '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR),\n      proto_type=variable_pb2.VariableDef,\n      to_proto=ref_variable._to_proto_fn,  # pylint: disable=protected-access\n      from_proto=ref_variable._from_proto_fn)  # pylint: disable=protected-access\n\n\ndef _is_iterable(obj):\n  \"\"\"A Python 2 and 3 compatible util to check whether `obj` is iterable.\"\"\"\n  try:\n    iter(obj)\n    return True\n  except TypeError:\n    return False\n\n\nclass CatchInvalidHostcallFunctions(control_flow_ops.XLAControlFlowContext):\n\n  def AddOp(self, op):\n    if op.type in [\n        'AudioSummary', 'AudioSummaryV2', 'HistogramSummary', 'ImageSummary',\n        'MergeSummary', 'ScalarSummary', 'TensorSummary', 'TensorSummaryV2'\n    ]:\n      raise ValueError('Please use tf.contrib.summary instead of tf.summary '\n                       'inside of host_calls.')\n\n\ndef _create_global_step(graph):\n  graph = graph or tf.compat.v1.get_default_graph()\n  if tf.compat.v1.train.get_global_step(graph) is not None:\n    raise ValueError('\"global_step\" already exists.')\n  # Create in proper graph and base name_scope.\n  with graph.as_default() as g, g.name_scope(None):\n    return tf.compat.v1.get_variable(\n        tf.compat.v1.GraphKeys.GLOBAL_STEP,\n        shape=[],\n        dtype=tf.dtypes.int64,\n        
initializer=tf.compat.v1.initializers.zeros(),\n        trainable=False,\n        use_resource=True,\n        collections=[\n            tf.compat.v1.GraphKeys.GLOBAL_VARIABLES,\n            tf.compat.v1.GraphKeys.GLOBAL_STEP\n        ])\n\n\ndef _create_or_get_iterations_per_loop():\n  \"\"\"Creates or gets the iterations_per_loop variable.\n\n  In TPUEstimator, the user provided computation, the model_fn, is wrapped\n  inside a tf.while_loop for peak performance. The iterations of the loop are\n  specified by this variable, which adjusts its value on the CPU after each TPU\n  program execution and before the next TPU execution.\n\n  The purpose of using a variable, rather then a constant, is to allow\n  TPUEstimator adapt the TPU training iterations according to the final steps\n  specified by users. For example, if the user sets the iterations_per_loop as 4\n  in TPUConfig and steps as 10 in TPUEstimator.train(), the iterations_per_loop\n  variable will have the following value before each TPU training.\n\n      - 1-th TPU execution: iterations_per_loop = 4\n      - 2-th TPU execution: iterations_per_loop = 4\n      - 3-th TPU execution: iterations_per_loop = 2\n\n  As model_fn increases the global step once per train_op invocation, the global\n  step is 10 after all TPU executions, matching the steps=10 inputs passed in by\n  users.\n\n  Returns:\n    A TF non-trainable resource variable.\n\n  Raises:\n    RuntimeError: If multi iterations_per_loop variables were found.\n  \"\"\"\n  graph = tf.compat.v1.get_default_graph()\n  collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR)\n  iter_vars = graph.get_collection(collection_name)\n  if len(iter_vars) == 1:\n    return iter_vars[0]\n  elif len(iter_vars) > 1:\n    raise RuntimeError('Multiple iterations_per_loop_var in collection.')\n\n  with ops.colocate_with(tf.compat.v1.train.get_global_step()):\n    with tf.compat.v1.variable_scope(\n        _TPU_ESTIMATOR, 
reuse=tf.compat.v1.AUTO_REUSE):\n      return tf.compat.v1.get_variable(\n          _ITERATIONS_PER_LOOP_VAR,\n          initializer=tf.compat.v1.initializers.zeros(),\n          shape=[],\n          dtype=tf.dtypes.int32,\n          trainable=False,\n          collections=[collection_name, tf.compat.v1.GraphKeys.LOCAL_VARIABLES],\n          use_resource=True)\n\n\ndef _sync_variables_ops(ctx):\n  \"\"\"Create varriables synchronization ops.\n\n  Gets the variables back from TPU nodes. This means the variables updated\n  by TPU will now be *synced* to host memory.\n  In BROADCAST mode, we skip this sync since the variables are ususally too\n  big to transmit via RPC.\n\n  Args:\n    ctx: A `_InternalTPUContext` instance with mode.\n\n  Returns:\n    A list of sync ops.\n  \"\"\"\n\n  if not ctx.is_input_broadcast_with_iterators():\n    return [\n        tf.debugging.check_numerics(v.read_value(),\n                                    'Gradient for %s is NaN' % v.name).op\n        for v in tf.compat.v1.trainable_variables()\n    ]\n  else:\n    return [tf.no_op()]\n\n\ndef _increase_eval_step_op(iterations_per_loop):\n  \"\"\"Returns an op to increase the eval step for TPU evaluation.\n\n  Args:\n    iterations_per_loop: Tensor. The number of eval steps running in TPU system\n      before returning to CPU host for each `Session.run`.\n\n  Returns:\n    An operation\n  \"\"\"\n  eval_step = evaluation._get_or_create_eval_step()  # pylint: disable=protected-access\n  # Estimator evaluate increases 1 by default. 
So, we increase the difference.\n  return tf.compat.v1.assign_add(\n      eval_step,\n      tf.cast(iterations_per_loop - 1, dtype=eval_step.dtype),\n      use_locking=True)\n\n\ndef _extract_key_names(tensor_or_dict):\n  if isinstance(tensor_or_dict, dict):\n    return sorted(tensor_or_dict.keys())\n  return []\n\n\nclass PeriodicLogger(object):\n\n  def __init__(self, seconds):\n    self._log_every_n_seconds = seconds\n    self._last_log_time = 0\n\n  def log(self, msg, *args, **kw):\n    if time.time() - self._last_log_time > self._log_every_n_seconds:\n      self._last_log_time = time.time()\n      tf.compat.v1.logging.info(msg, *args, **kw)\n\n\nclass _SIGNAL(object):\n  \"\"\"Signal used to control the thread of infeed/outfeed.\n\n  All preserved signals must be negative numbers. Positive numbers are used to\n  indicate the number of iterations for next training/evaluation loop.\n  \"\"\"\n  NEXT_BATCH = -1\n  STOP = -2\n\n\n@estimator_export(v1=['estimator.tpu.TPUEstimatorSpec'])\nclass TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access\n  \"\"\"Ops and objects returned from a `model_fn` and passed to `TPUEstimator`.\n\n  See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and\n  `export_outputs`.\n\n  For evaluation, `eval_metrics `is a tuple of `metric_fn` and `tensors`, where\n  `metric_fn` runs on CPU to generate metrics and `tensors` represents the\n  `Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`.\n  To be precise, TPU evaluation expects a slightly different signature from the\n  `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a\n  dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.\n  The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The\n  `tensors` usually specify the model logits, which are transferred back from\n  TPU system to CPU host. 
All tensors must have be batch-major, i.e., the batch\n  size is the first dimension. Once all tensors are available at CPU host from\n  all shards, they are concatenated (on CPU) and passed as positional arguments\n  to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is\n  a dict. `metric_fn` takes the `tensors` and returns a dict from metric string\n  name to the result of calling a metric function, namely a `(metric_tensor,\n  update_op)` tuple. See `TPUEstimator` for MNIST example how to specify the\n  `eval_metrics`.\n\n  `scaffold_fn` is a function running on CPU to generate the `Scaffold`. This\n  function should not capture any Tensors in `model_fn`.\n\n  `host_call` is a tuple of a `function` and a list or dictionary of `tensors`\n  to pass to that function and returns a list of Tensors. `host_call` currently\n  works for train() and evaluate(). The Tensors returned by the function is\n  executed on the CPU on every step, so there is communication overhead when\n  sending tensors from TPU to CPU. To reduce the overhead, try reducing the\n  size of the tensors. The `tensors` are concatenated along their major (batch)\n  dimension, and so must be >= rank 1. The `host_call` is useful for writing\n  summaries with `tf.summary.create_file_writer`.\n\n  @compatibility(TF2)\n  TPU Estimator manages its own TensorFlow graph and session, so it is not\n  compatible with TF2 behaviors. We recommend that you migrate to the newer\n  `tf.distribute.TPUStrategy`. 
See the\n  [TPU guide](https://www.tensorflow.org/guide/tpu) for details.\n  @end_compatibility\n  \"\"\"\n\n  def __new__(cls,\n              mode,\n              predictions=None,\n              loss=None,\n              train_op=None,\n              eval_metrics=None,\n              export_outputs=None,\n              scaffold_fn=None,\n              host_call=None,\n              training_hooks=None,\n              evaluation_hooks=None,\n              prediction_hooks=None):\n    \"\"\"Creates a validated `TPUEstimatorSpec` instance.\"\"\"\n    cls._host_calls = {}\n    if eval_metrics is not None:\n      cls._host_calls['eval_metrics'] = eval_metrics\n    if host_call is not None:\n      cls._host_calls['host_call'] = host_call\n    _OutfeedHostCall.validate(cls._host_calls)\n\n    training_hooks = tuple(training_hooks or [])\n    evaluation_hooks = tuple(evaluation_hooks or [])\n    prediction_hooks = tuple(prediction_hooks or [])\n\n    for hook in training_hooks + evaluation_hooks + prediction_hooks:\n      if not isinstance(hook, tf.compat.v1.train.SessionRunHook):\n        raise TypeError(\n            'All hooks must be SessionRunHook instances, given: {}'.format(\n                hook))\n\n    return super(TPUEstimatorSpec, cls).__new__(\n        cls,\n        mode=mode,\n        predictions=predictions,\n        loss=loss,\n        train_op=train_op,\n        eval_metrics=eval_metrics,\n        export_outputs=export_outputs,\n        scaffold_fn=scaffold_fn,\n        host_call=host_call,\n        training_hooks=training_hooks,\n        evaluation_hooks=evaluation_hooks,\n        prediction_hooks=prediction_hooks)\n\n  def as_estimator_spec(self):\n    \"\"\"Creates an equivalent `EstimatorSpec` used by CPU train/eval.\"\"\"\n    host_call_ret = _OutfeedHostCall.create_cpu_hostcall(self._host_calls)\n    eval_metric_ops = None\n    if self.eval_metrics is not None:\n      eval_metric_ops = host_call_ret['eval_metrics']\n    hooks = None\n    if 
self.host_call is not None:\n      hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]\n    loss = self.loss\n    if tensor_tracer.TensorTracer.is_enabled() \\\n       and self.train_op is not None:\n      tt = tensor_tracer.TensorTracer()\n      loss = tt.trace_cpu(tf.compat.v1.get_default_graph(), loss, self.train_op)\n\n    hooks = tuple(hooks or [])\n    scaffold = self.scaffold_fn() if self.scaffold_fn else None\n    return model_fn_lib.EstimatorSpec(\n        mode=self.mode,\n        predictions=self.predictions,\n        loss=loss,\n        train_op=self.train_op,\n        eval_metric_ops=eval_metric_ops,\n        export_outputs=self.export_outputs,\n        scaffold=scaffold,\n        training_hooks=self.training_hooks + hooks,\n        evaluation_hooks=self.evaluation_hooks + hooks,\n        prediction_hooks=self.prediction_hooks + hooks)\n\n\nclass _OpQueueContext(object):\n  \"\"\"Manages work queue and thread for a infeed/outfeed thread.\"\"\"\n\n  def __init__(self, name, target, args):\n    self._name = name\n    self._queue = Queue.Queue()\n    args = (self,) + args\n    self._thread = threading.Thread(name=name, target=target, args=args)\n    self._thread.daemon = True\n    self._thread.start()\n\n  def stop(self):\n    self._queue.put(_SIGNAL.STOP)\n\n  def send_next_batch_signal(self, iterations):\n    self._queue.put(iterations)\n\n  def read_iteration_counts(self):\n    while True:\n      iterations = self._queue.get(block=True)\n      tf.compat.v1.logging.debug('%s read iterations %s', self._name,\n                                 iterations)\n      if iterations == _SIGNAL.STOP:\n        tf.compat.v1.logging.info('%s received shutdown signal, stopping.',\n                                  self._name)\n        return\n      yield iterations\n\n  def join(self):\n    tf.compat.v1.logging.info('Shutting down %s thread.', self._name)\n    self.stop()\n    self._thread.join()\n\n\nclass _OpSignalOnceQueueContext(_OpQueueContext):\n  
\"\"\"Manages work queue and thread for a infeed/outfeed thread.\n\n  This subclass only signals once.\n  \"\"\"\n\n  def __init__(self, name, target, args):\n    super(_OpSignalOnceQueueContext, self).__init__(name, target, args)\n    self._has_signaled = False\n\n  def send_next_batch_signal(self, iterations):\n    if not self._has_signaled:\n      self._queue.put(iterations)\n      self._has_signaled = True\n\n\nclass TPUInfeedOutfeedSessionHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"A Session hook setting up the TPU initialization, infeed, and outfeed.\n\n  This hook does two major things:\n  1. initialize and shutdown TPU system.\n  2. launch and join the threads for infeed enqueue and (optional) outfeed\n     dequeue.\n  \"\"\"\n\n  def __init__(self,\n               ctx,\n               enqueue_ops,\n               dequeue_ops,\n               tpu_compile_op,\n               run_infeed_loop_on_coordinator=True,\n               rendezvous=None,\n               master=None,\n               session_config=None,\n               tpu_init_ops=None,\n               outfeed_every_n_steps=1):\n    self._master_job = ctx.master_job\n    self._enqueue_ops = enqueue_ops\n    self._dequeue_ops = dequeue_ops\n    self._rendezvous = rendezvous\n    self._master = master\n    self._session_config = session_config\n    self._init_ops = list(tpu_init_ops or [])\n    if ctx.embedding_config is None:\n      self._embedding_layer_config = None\n    else:\n      self._embedding_layer_config = (\n          ctx.embedding_config.tpu_embedding.config_proto)\n    self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator\n    self._initial_infeed_sleep_secs = (\n        ctx.config.tpu_config.initial_infeed_sleep_secs)\n    self._tpu_compile_op = tpu_compile_op\n\n    # When using model parallelism, the TPU is pre-initialized at startup to\n    # fetch mesh information. We skip re-initializing it here for\n    # MeshTensorFlow since it places variables on TPU directly. 
Reinitialize tpu\n    # is causing the variable corruption since the previous allocated memory\n    # might be overwritten for other purpose.\n    if (ctx.model_parallelism_enabled and\n        (ctx.config.tpu_config.per_host_input_for_training is\n         tpu_config.InputPipelineConfig.BROADCAST)):\n      self._should_initialize_tpu = False\n    else:\n      self._should_initialize_tpu = True\n    self._outfeed_every_n_steps = outfeed_every_n_steps\n\n  def begin(self):\n    tf.compat.v1.logging.info('TPU job name %s', self._master_job)\n    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()\n    if self._should_initialize_tpu:\n      self._finalize_ops = [\n          tf.compat.v1.tpu.shutdown_system(job=self._master_job)\n      ]\n    else:\n      self._finalize_ops = []\n\n    summary_writer_init_ops = summary_ops_v2.summary_writer_initializer_op()\n    self._init_ops.extend(summary_writer_init_ops)\n    # Get all the writer resources from the initializer, so we know what to\n    # flush.\n    for op in summary_writer_init_ops:\n      self._finalize_ops.append(\n        summary_ops_v2.legacy_raw_flush(writer=op.inputs[0]))\n\n  def _run_infeed(self, queue_ctx, session):\n    tf.compat.v1.logging.info('Starting infeed thread controller.')\n    if self._initial_infeed_sleep_secs:\n      tf.compat.v1.logging.info('Infeed thread sleeping for %d seconds.',\n                                self._initial_infeed_sleep_secs)\n      time.sleep(self._initial_infeed_sleep_secs)\n      tf.compat.v1.logging.info('Infeed thread starting after sleep')\n\n    with self._rendezvous.catch_errors(source='infeed', session=session):\n      if self._run_infeed_loop_on_coordinator:\n        for count, steps in enumerate(queue_ctx.read_iteration_counts()):\n          for i in xrange(steps):\n            tf.compat.v1.logging.debug('Infeed enqueue for iteration (%d, %d)',\n                                       count, i)\n            session.run(self._enqueue_ops)\n     
 else:\n        for _ in queue_ctx.read_iteration_counts():\n          session.run(self._enqueue_ops)\n      tf.compat.v1.logging.info('Infeed thread finished, shutting down.')\n\n  def _run_outfeed(self, queue_ctx, session):\n    tf.compat.v1.logging.info('Starting outfeed thread controller.')\n    status_logger = PeriodicLogger(seconds=60)\n    with self._rendezvous.catch_errors(source='outfeed', session=session):\n      for count, steps in enumerate(queue_ctx.read_iteration_counts()):\n        step_counter = 0\n        for i in xrange(steps):\n          tf.compat.v1.logging.debug('Outfeed dequeue for iteration (%d, %d)',\n                                     count, i)\n          if step_counter % self._outfeed_every_n_steps == 0:\n            session.run(self._dequeue_ops)\n          step_counter += 1\n          status_logger.log('Outfeed finished for iteration (%d, %d)', count, i)\n      tf.compat.v1.logging.info('Outfeed thread finished, shutting down.')\n\n  def _create_infeed_controller(self, name, target, args):\n    return _OpQueueContext(name=name, target=target, args=args)\n\n  def _assertCompilationSucceeded(self, result, coord):\n    proto = tpu_compilation_result.CompilationResultProto()\n    proto.ParseFromString(result)\n    if proto.status_error_message:\n      tf.compat.v1.logging.error('Compilation failed: {}'.format(\n          proto.status_error_message))\n      coord.request_stop()\n    else:\n      tf.compat.v1.logging.info('Compilation succeeded')\n\n  def after_create_session(self, session, coord):\n    if self._should_initialize_tpu:\n      tf.compat.v1.logging.info('Init TPU system')\n      start = time.time()\n      with tf.Graph().as_default():\n        with tf.compat.v1.Session(\n            self._master, config=self._session_config) as sess:\n          sess.run(\n              tf.compat.v1.tpu.initialize_system(\n                  job=self._master_job,\n                  embedding_config=self._embedding_layer_config))\n      
tf.compat.v1.logging.info('Initialized TPU in %d seconds',\n                                time.time() - start)\n\n    session.run(\n        self._init_ops,\n        options=tf.compat.v1.RunOptions(timeout_in_ms=30 * 60 * 1000))\n\n    if os.environ.get('TPU_SPLIT_COMPILE_AND_EXECUTE', '') == '1':\n      tf.compat.v1.logging.info(\n          'Compiling user program: this may take a while...')\n      self._assertCompilationSucceeded(session.run(self._tpu_compile_op), coord)\n\n    self._infeed_controller = self._create_infeed_controller(\n        name='InfeedController', target=self._run_infeed, args=(session,))\n\n    self._outfeed_controller = _OpQueueContext(\n        name='OutfeedController', target=self._run_outfeed, args=(session,))\n\n    # Enable the worker watchdog to terminate workers on coordinator exit.\n    watchdog_timeout = int(os.environ.get('TF_TPU_WATCHDOG_TIMEOUT', '0'))\n    if watchdog_timeout > 0:\n      session_support.start_worker_watchdog(\n          session, shutdown_timeout=watchdog_timeout)\n\n  def before_run(self, run_context):\n    iterations = run_context.session.run(self._iterations_per_loop_var)\n\n    tf.compat.v1.logging.info('Enqueue next (%d) batch(es) of data to infeed.',\n                              iterations)\n    self._infeed_controller.send_next_batch_signal(iterations)\n\n    tf.compat.v1.logging.info(\n        'Dequeue next (%d) batch(es) of data from outfeed.', iterations)\n    self._outfeed_controller.send_next_batch_signal(iterations)\n\n  def end(self, session):\n    tf.compat.v1.logging.info('Stop infeed thread controller')\n    self._infeed_controller.join()\n    self._rendezvous.record_done('infeed')\n\n    tf.compat.v1.logging.info('Stop output thread controller')\n    self._outfeed_controller.join()\n    self._rendezvous.record_done('outfeed')\n\n    tf.compat.v1.logging.info('Shutdown TPU system.')\n    session.run(self._finalize_ops)\n\n\nclass 
TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook):\n\n  def __init__(self,\n               ctx,\n               enqueue_ops,\n               dequeue_ops,\n               tpu_compile_op,\n               rendezvous=None,\n               master=None,\n               session_config=None):\n    super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__(\n        ctx,\n        enqueue_ops,\n        dequeue_ops,\n        tpu_compile_op=tpu_compile_op,\n        run_infeed_loop_on_coordinator=False,\n        rendezvous=rendezvous,\n        master=master,\n        session_config=session_config)\n\n  def _create_infeed_controller(self, name, target, args):\n    return _OpSignalOnceQueueContext(name=name, target=target, args=args)\n\n\nclass _TPUStopAtStepHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Hook that requests stop at a specified step.\n\n  This hook is similar to the `session_run_hook._StopAfterNEvalsHook` with\n  following differences for TPU training:\n\n  1. This hook sets the variable for `iterations_per_loop`, which is used by\n     `TPUInfeedOutfeedSessionHook` to control the iterations for infeed/outfeed.\n     If the `iterations_per_loop` value is specified as time in seconds, the\n     number of iterations per `Session.run` will be estimated automatically\n     based on per iteration runtime.\n\n     As the hook execution order is not guaranteed, the variable update is\n     handled in `after_create_session` and `after_run` as\n     `TPUInfeedOutfeedSessionHook` reads the variable value in `before_run`.\n\n  2. For each training loop (session.run), the global step could be increased\n     multiple times on TPU. 
The global step tensor value will be explicitly read\n     again in `after_run` to ensure the latest value is retrieved to avoid race\n     condition.\n  \"\"\"\n\n  def __init__(self,\n               iterations_per_loop_counter,\n               num_steps=None,\n               final_step=None):\n    \"\"\"Initializes a `TPUStopAtStepHook`.\n\n    Args:\n      iterations_per_loop_counter: A namedtuple of [`value',`unit`] that\n        represents the number of 'iterations count' or 'time in seconds' to run\n        optimizer per loop, based on the `unit` specified, `count` or `seconds`\n        respectively.\n      num_steps: Number of steps to execute.\n      final_step: Step after which to stop.\n\n    Raises:\n      ValueError: If one of the arguments is invalid.\n    \"\"\"\n    if num_steps is None and final_step is None:\n      raise ValueError('One of `num_steps` or `final_step` must be specified.')\n    if num_steps is not None and final_step is not None:\n      raise ValueError(\n          'Only one of `num_steps` or `final_step` can be specified.')\n    self._iterations_per_loop_counter = iterations_per_loop_counter\n    if self._iterations_per_loop_counter.unit not in ['seconds', 'count']:\n      raise ValueError('Only `count` or `seconds` are accepted as the '\n                       '`iterations_per_loop_counter.unit')\n    self._num_steps = num_steps\n    self._final_step = final_step\n    self._next_iteration_count = 1\n    self._iteration_count_estimator = None\n    if self._iterations_per_loop_counter.unit == 'seconds':\n      self._iteration_count_estimator = (\n          iteration_count_estimator.IterationCountEstimator())\n    self._start_time = time.time()\n\n  def _next_iterations(self, global_step, final_step):\n    \"\"\"Computes the next iterations count.\n\n    The next iterations count is computed by choosing the smaller of the\n    remaining step count (`final_step` - `global_step`) and the estimated\n    iterations count returned by the 
estimator.\n\n    Args:\n      global_step: The current step.\n      final_step: Step after which to stop.\n\n    Returns:\n      The number of iterations count to run per loop.\n    \"\"\"\n    remaining_steps = final_step - global_step\n\n    if self._iteration_count_estimator is not None:\n      estimated_iterations = self._iteration_count_estimator.get(\n          self._iterations_per_loop_counter.value)\n    else:\n      estimated_iterations = self._iterations_per_loop_counter.value\n\n    self._next_iteration_count = min(remaining_steps, estimated_iterations)\n    return self._next_iteration_count\n\n  def begin(self):\n    \"\"\"Initializes variables.\n\n    Initializes the global step and iterations per loop variables.\n\n    Raises:\n      RuntimeError: An error occurred if global step variable does not exist.\n    \"\"\"\n    self._global_step_tensor = tf.compat.v1.train.get_global_step()\n    if self._global_step_tensor is None:\n      raise RuntimeError('Global step should be created.')\n\n    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()\n\n  def after_create_session(self, session, coord):\n    \"\"\"Computes and updates the first time iterations count.\n\n    The iterations are computed by choosing the smaller of the (`final step` -\n    `global step`), and the initial estimated iterations returned by the\n    estimator (by default is 1).\n\n    Args:\n      session: A TensorFlow Session that has been created.\n      coord: A Coordinator object which keeps track of all threads.\n    \"\"\"\n    global_step = session.run(self._global_step_tensor)\n    if self._final_step is None:\n      self._final_step = global_step + self._num_steps\n\n    iterations = self._next_iterations(global_step, self._final_step)\n    self._iterations_per_loop_var.load(iterations, session=session)\n\n  def before_run(self, run_context):\n    \"\"\"Reset the timer.\"\"\"\n    if self._iteration_count_estimator is not None:\n      self._start_time = 
time.time()\n\n  def after_run(self, run_context, run_values):\n    \"\"\"Computes the next iterations per loop value or terminates.\n\n    Computes the elapsed time to run the last optimizer loop and if the\n    `IterationCountEstimator` is used, records the elapsed time and iterations\n    count. If the final step count has been reached, terminates. Otherwise,\n    computes and updates the number of iterations to run the optimizer per loop.\n\n    Args:\n      run_context: A `SessionRunContext` object.\n      run_values: A SessionRunValues object.\n    \"\"\"\n    if self._iteration_count_estimator is not None:\n      elapsed_time = time.time() - self._start_time\n      tf.compat.v1.logging.info('ElapsedTime: %.3f', elapsed_time)\n      self._iteration_count_estimator.update(elapsed_time,\n                                             self._next_iteration_count)\n\n    # Global step cannot be retrieved via SessionRunArgs and before_run due to\n    # race condition.\n    global_step = run_context.session.run(self._global_step_tensor)\n    if global_step >= self._final_step:\n      run_context.request_stop()\n    else:\n      iterations = self._next_iterations(global_step, self._final_step)\n      self._iterations_per_loop_var.load(\n          iterations, session=run_context.session)\n\n\nclass _SetEvalIterationsHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Hook that requests stop at a specified step.\"\"\"\n\n  def __init__(self, num_steps):\n    \"\"\"Initializes a `_SetEvalIterationsHook`.\n\n    Args:\n      num_steps: Number of steps to execute.\n    \"\"\"\n    self._num_steps = num_steps\n\n  def begin(self):\n    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()\n\n  def after_create_session(self, session, coord):\n    self._iterations_per_loop_var.load(self._num_steps, session=session)\n\n\nclass _StoppingPredictHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Hook that requests stop according to the stopping signal in 
prediction.\"\"\"\n\n  def __init__(self, scalar_stopping_signal):\n    self._scalar_stopping_signal = scalar_stopping_signal\n\n  def begin(self):\n    self._iterations_per_loop_var = _create_or_get_iterations_per_loop()\n\n  def after_create_session(self, session, coord):\n    # This is not necessary as we do not run infeed enqueue and outfeed dequeue\n    # in side threads for prediction model. But it makes the\n    # TPUInfeedOutfeedSessionHook prints nice message.\n    self._iterations_per_loop_var.load(1, session=session)\n\n  def before_run(self, run_context):\n    return tf.compat.v1.train.SessionRunArgs(self._scalar_stopping_signal)\n\n  def after_run(self, run_context, run_values):\n    _ = run_context\n    scalar_stopping_signal = run_values.results\n    if _StopSignals.should_stop(scalar_stopping_signal):\n      # NOTE(xiejw): In prediction, stopping signals are inserted for each\n      # batch. And we append one more batch to signal the system it should stop.\n      # The data flow might look like\n      #\n      #  batch   0: images, labels, stop = 0  (user provided)\n      #  batch   1: images, labels, stop = 0  (user provided)\n      #  ...\n      #  batch  99: images, labels, stop = 0  (user provided)\n      #  batch 100: images, labels, stop = 1  (TPUEstimator appended)\n      #\n      # where the final batch (id = 100) is appended by TPUEstimator, so we\n      # should drop it before returning the predictions to user.\n      # To achieve that, we throw the OutOfRangeError in after_run. 
Once\n      # Monitored Session sees this error in SessionRunHook.after_run, the\n      # \"current\" prediction, i.e., batch with id=100, will be discarded\n      # immediately\n      raise tf.errors.OutOfRangeError(None, None, 'Stopped by stopping signal.')\n\n\ndef generate_per_core_enqueue_ops_fn_for_host(ctx, input_fn,\n                                              inputs_structure_recorder,\n                                              host_device, host_id):\n  \"\"\"Generates infeed enqueue ops for per-core input_fn on a single host.\"\"\"\n  captured_infeed_queue = _CapturedObject()\n  tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)\n\n  def enqueue_ops_fn():\n    \"\"\"A fn returns enqueue_ops.\"\"\"\n    num_cores_per_host = ctx.num_of_cores_per_host\n    per_host_sharded_inputs = []\n    for core_ordinal in range(num_cores_per_host):\n      with ops.name_scope('ordinal_%d' % (core_ordinal)):\n        user_context = tpu_context.TPUContext(\n            internal_ctx=ctx,\n            input_device=host_device,\n            invocation_index=host_id * ctx.num_of_cores_per_host + core_ordinal,\n            host_id=host_id)\n        inputs = _Inputs.from_input_fn(input_fn(user_context))\n        if inputs.is_dataset:\n          raise TypeError(\n              '`input_fn` returning `Dataset`  is not yet supported in '\n              'per-Core input pipeline deployment yet. 
Please set '\n              'TPUConfig.per_host_input_for_training to True or return '\n              '`features` and `labels` from `input_fn`')\n        features, labels = inputs.features_and_labels()\n\n        inputs_structure_recorder.validate_and_record_structure(\n            features, labels)\n        flattened_inputs = (\n            inputs_structure_recorder.flatten_features_and_labels(\n                features, labels))\n        per_host_sharded_inputs.append(flattened_inputs)\n\n    infeed_queue = tpu_feed.InfeedQueue(\n        number_of_tuple_elements=len(per_host_sharded_inputs[0]))\n    captured_infeed_queue.capture(infeed_queue)\n\n    per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(\n        per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl)\n    return per_host_enqueue_ops\n\n  return enqueue_ops_fn, captured_infeed_queue\n\n\ndef generate_per_host_enqueue_ops_fn_for_host(ctx, input_fn,\n                                              inputs_structure_recorder,\n                                              batch_axis, device, host_id):\n  \"\"\"Generates infeed enqueue ops for per-host input_fn on a single host.\"\"\"\n  captured_infeed_queue = _CapturedObject()\n\n  dataset_initializer = None\n\n  with tf.compat.v1.device(device):\n    user_context = tpu_context.TPUContext(\n        internal_ctx=ctx,\n        input_device=device,\n        invocation_index=host_id,\n        host_id=host_id)\n    inputs = _Inputs.from_input_fn(input_fn(user_context))\n\n    is_dataset = inputs.is_dataset\n    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:\n      if not is_dataset:\n        raise TypeError(\n            'For mode PREDICT, `input_fn` must return `Dataset` instead of '\n            '`features` and `labels`.')\n      if batch_axis is not None:\n        raise TypeError('For mode PREDICT, batch_axis is not supported yet.')\n      inputs = _InputsWithStoppingSignals(\n          dataset=inputs.dataset,\n          
batch_size=ctx.batch_size_for_input_fn,\n          add_padding=True)\n\n    if is_dataset:\n      dataset_initializer = inputs.dataset_initializer()\n\n    tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)\n\n  def enqueue_ops_fn():\n    \"\"\"A Fn returning the TPU infeed enqueue ops.\n\n    By providing as a Fn, it can be invoked inside the tf.while_loop such that\n    the input pipeline for multiple iterations can be executed by one\n    Session.run call.\n\n    Returns:\n      list of dict of ops.\n    \"\"\"\n    with tf.compat.v1.device(device):\n      num_of_replicas_per_host = ctx.num_of_replicas_per_host\n      # Convert user input to features and labels.  If the user returns a\n      # dataset, it is initialized and the features and labels extracted via\n      # `dataset.iterator.get_next()`\n      features, labels = inputs.features_and_labels()\n      signals = inputs.signals()\n\n      features, labels, enqueue_datas_list = (\n          _tpu_estimator_embedding.split_inputs(\n              ctx,\n              features,\n              labels,\n              num_cores_per_batch=num_of_replicas_per_host))\n\n      inputs_structure_recorder.validate_and_record_structure(features, labels)\n      unsharded_tensor_list = (\n          inputs_structure_recorder.flatten_features_and_labels(\n              features, labels, signals))\n\n      infeed_queue = tpu_feed.InfeedQueue(\n          tuple_types=[t.dtype for t in unsharded_tensor_list],\n          tuple_shapes=[t.shape for t in unsharded_tensor_list],\n          shard_dimensions=batch_axis)\n      captured_infeed_queue.capture(infeed_queue)\n      infeed_queue.set_number_of_shards(num_of_replicas_per_host)\n      per_host_enqueue_ops = (\n          infeed_queue.split_inputs_and_generate_enqueue_ops(\n              unsharded_tensor_list,\n              placement_function=lambda x: device,\n              tpu_ordinal_function=tpu_ordinal_function_impl))\n\n      if ctx.embedding_config:\n        
per_host_enqueue_ops.extend(\n            ctx.embedding_config.tpu_embedding.generate_enqueue_ops(\n                enqueue_datas_list))\n\n      if signals is None:\n        return per_host_enqueue_ops\n      else:\n        return {\n            'ops': per_host_enqueue_ops,\n            'signals': signals,\n        }\n\n  return enqueue_ops_fn, captured_infeed_queue, dataset_initializer\n\n\ndef generate_per_host_v2_enqueue_ops_fn_for_host(ctx, input_fn,\n                                                 inputs_structure_recorder,\n                                                 device, host_id,\n                                                 invocation_index):\n  \"\"\"Generates infeed enqueue ops for per-host input_fn on a single host.\"\"\"\n  captured_infeed_queue = _CapturedObject()\n  dataset_initializer = None\n\n  with tf.compat.v1.device(device):\n    user_context = tpu_context.TPUContext(\n        internal_ctx=ctx,\n        input_device=device,\n        invocation_index=invocation_index,\n        host_id=host_id)\n    inputs = _Inputs.from_input_fn(input_fn(user_context))\n\n    is_dataset = inputs.is_dataset\n    if not is_dataset:\n      raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 '\n                      'input pipeline configuration.')\n\n    # Be aware that when num_cores_per_replica > num_cores_per_host,\n    # ctx.num_of_replicas_per_host is 0.\n    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:\n      inputs = _InputsWithStoppingSignals(\n          dataset=inputs.dataset,\n          batch_size=ctx.batch_size_for_input_fn,\n          add_padding=True,\n          num_invocations_per_step=max(1, ctx.num_of_replicas_per_host))\n\n    dataset_initializer = inputs.dataset_initializer()\n\n    tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id)\n\n    def device_function_impl(shard_id):\n      if ctx.device_assignment is not None:\n        # Find the replica_id of the host's logical core 0.\n        # The current 
host_id is guaranteed to contain the logical core 0,\n        # even when num_cores_per_replica > num_cores_per_host -- the function\n        # caller makes sure that this host_id will must be receiving data (calls\n        # input_fn).\n        replica_id = ctx.device_assignment.lookup_replicas(\n            task_id=host_id, logical_core=0)[shard_id]\n        return ctx.tpu_host_placement_function(replica_id=replica_id)\n      else:\n        return None\n\n  def enqueue_ops_fn():\n    \"\"\"Generates the per_host enqueue ops.\"\"\"\n    control_deps = []\n    per_host_sharded_inputs = []\n    enqueue_datas_list = []\n    # Be aware that when num_cores_per_replica > num_cores_per_host,\n    # ctx.num_of_replicas_per_host is 0.\n    num_replicas_per_host = max(1, ctx.num_of_replicas_per_host)\n    cached_signals = None\n    with tf.compat.v1.device(device):\n      if not inputs.is_dataset:\n        raise TypeError('`input_fn` must return a `Dataset` for this mode.')\n      for host in range(num_replicas_per_host):\n        # Use control dependencies to ensure a deterministic ordering.\n        if ctx.allow_per_host_v2_parallel_get_next:\n          features, labels = inputs.features_and_labels()  # Calls get_next()\n        with tf.control_dependencies(control_deps):\n          if not ctx.allow_per_host_v2_parallel_get_next:\n            features, labels = inputs.features_and_labels()  # Calls get_next()\n          signals = inputs.signals()\n\n          # All the replicas share the replica 0's stopping signal.\n          # This avoids inconsistent state among different model replcias.\n          if cached_signals:\n            signals['stopping'] = cached_signals['stopping']\n          else:\n            cached_signals = signals\n\n        features, labels, enqueue_data = (\n            _tpu_estimator_embedding.split_inputs(ctx, features, labels))\n        if len(enqueue_data) != 1:\n          raise RuntimeError(('Missing or extra enqueue_data for host {}. 
'\n                              'len(enqueue_data) = {}.').format(\n                                 host, len(enqueue_data)))\n        enqueue_datas_list.append(enqueue_data[0])\n\n        inputs_structure_recorder.validate_and_record_structure(\n            features, labels)\n        flattened_inputs = (\n            inputs_structure_recorder.flatten_features_and_labels(\n                features, labels, signals))\n        control_deps.extend(flattened_inputs)\n        per_host_sharded_inputs.append(flattened_inputs)\n\n      if inputs_structure_recorder.flattened_input_dims:\n        input_partition_dims = inputs_structure_recorder.flattened_input_dims\n        if signals:\n          input_partition_dims += [None] * len(signals)\n        # pylint: disable=protected-access\n        infeed_queue = tpu_feed._PartitionedInfeedQueue(\n            number_of_tuple_elements=len(per_host_sharded_inputs[0]),\n            host_id=host_id,\n            input_partition_dims=input_partition_dims,\n            device_assignment=ctx.device_assignment)\n        per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(\n            per_host_sharded_inputs)\n      else:\n        infeed_queue = tpu_feed.InfeedQueue(\n            number_of_tuple_elements=len(per_host_sharded_inputs[0]))\n        per_host_enqueue_ops = infeed_queue.generate_enqueue_ops(\n            per_host_sharded_inputs,\n            tpu_ordinal_function=tpu_ordinal_function_impl,\n            placement_function=device_function_impl)\n\n      captured_infeed_queue.capture(infeed_queue)\n\n    if ctx.embedding_config:\n      per_host_enqueue_ops.extend(\n          ctx.embedding_config.tpu_embedding.generate_enqueue_ops(\n              enqueue_datas_list))\n\n    if signals is None:\n      return per_host_enqueue_ops\n    else:\n      return {\n          'ops': per_host_enqueue_ops,\n          'signals': signals,\n      }\n\n  return enqueue_ops_fn, captured_infeed_queue, dataset_initializer\n\n\ndef 
generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,\n                                      num_hosts):\n  \"\"\"Generates infeed enqueue ops for one input_fn on all the hosts.\"\"\"\n  captured_infeed_queue = _CapturedObject()\n  dataset_initializer = None\n  device_0 = ctx.tpu_host_placement_function(host_id=0)\n  with tf.compat.v1.device(device_0):\n    user_context = tpu_context.TPUContext(\n        internal_ctx=ctx, input_device=device_0, invocation_index=0, host_id=0)\n    inputs = _Inputs.from_input_fn(input_fn(user_context))\n\n    is_dataset = inputs.is_dataset\n    if ctx.mode == model_fn_lib.ModeKeys.PREDICT:\n      if not is_dataset:\n        raise TypeError(\n            'For mode PREDICT, `input_fn` must return `Dataset` instead of '\n            '`features` and `labels`.')\n\n      inputs = _InputsWithStoppingSignals(\n          dataset=inputs.dataset,\n          batch_size=ctx.batch_size_for_input_fn,\n          add_padding=True)\n\n    if is_dataset:\n      dataset_initializer = inputs.dataset_initializer()\n    num_replicas_per_host = ctx.num_of_replicas_per_host\n\n  def tpu_ordinal_function_impl(shard_id):\n    if ctx.device_assignment:\n      return ctx.device_assignment.tpu_ordinal(replica=shard_id)\n    else:\n      return shard_id % num_replicas_per_host\n\n  def device_function_impl(shard_id):\n    # shard_id ranges from 0 to num_of_replicas_per_host - 1.\n    # A shard is a replica inside a host.\n    # In broadcast mode (generate_broadcast_enqueue_ops_fn), the enqueue ops\n    # are always executed on the first host. 
Thus shard_id equals to replica_id.\n    return ctx.tpu_host_placement_function(replica_id=shard_id)\n\n  def enqueue_ops_fn():\n    \"\"\"Generates enqueue ops for all the hosts.\"\"\"\n    broadcasted_inputs = []\n    flattened_inputs = None  # Cache result from input_fn.\n    signals = None\n    num_replicas = ctx.num_replicas\n    core_id = 0\n    for host_id in xrange(num_hosts):\n      with tf.compat.v1.device(\n          ctx.tpu_host_placement_function(host_id=host_id)):\n        for _ in xrange(ctx.num_of_replicas_per_host):\n          # Note: input_fn is only called once at host 0 for the first replica.\n          # The features and labels returned from that invocation are\n          # broadcasted to other replicas(including the replicas on other\n          # hosts).\n          if flattened_inputs is None:\n            features, labels = inputs.features_and_labels()  # Calls get_next()\n            signals = inputs.signals()\n\n            inputs_structure_recorder.validate_and_record_structure(\n                features, labels)\n            flattened_inputs = (\n                inputs_structure_recorder.flatten_features_and_labels(\n                    features, labels, signals))\n            if (ctx.config.tpu_config.eval_training_input_configuration is\n                tpu_config.InputPipelineConfig.SLICED):\n              input_slices = [\n                  tf.split(x, num_replicas) for x in flattened_inputs\n              ]\n          if (ctx.config.tpu_config.eval_training_input_configuration is\n              tpu_config.InputPipelineConfig.SLICED):\n            # for each core, slice out the flattened_inputs for each core.\n            broadcasted_inputs.append([x[core_id] for x in input_slices])\n            core_id += 1\n          else:\n            broadcasted_inputs.append(flattened_inputs)\n\n    infeed_queue = tpu_feed.InfeedQueue(\n        number_of_tuple_elements=len(broadcasted_inputs[0]))\n    captured_infeed_queue.capture(infeed_queue)\n 
   enqueue_ops = infeed_queue.generate_enqueue_ops(\n        broadcasted_inputs,\n        tpu_ordinal_function=tpu_ordinal_function_impl,\n        placement_function=device_function_impl)\n\n    if signals is None:\n      return enqueue_ops\n    else:\n      return {\n          'ops': enqueue_ops,\n          'signals': signals,\n      }\n\n  return enqueue_ops_fn, captured_infeed_queue, dataset_initializer\n\n\nclass TensorPacker(object):\n  \"\"\"Pack and unpack small tensors into a big one for efficiency.\"\"\"\n\n  def __init__(self, small_feature_dim_size,\n               minimum_num_small_features_to_group):\n    self._small_feature_dim_size = small_feature_dim_size\n    self._minimum_num_small_features_to_group = (\n        minimum_num_small_features_to_group)\n\n  def maybe_concatenate_features(self, features):\n    \"\"\"If there are enough small tensors, concat them for performance.\"\"\"\n    self._small_feature_names = {}\n    self._small_feature_sizes = {}\n    feature_names = _extract_key_names(features)\n    if feature_names:  # Not a single tensor.\n      # First pass: see if it is worth concatenating the small features.\n      for name in feature_names:\n        tensor = features[name]\n        # We do not handle nested inputs here.\n        if not isinstance(tensor, tf.Tensor):\n          return\n        shape = tensor.get_shape().as_list()\n        dtype = tensor.dtype\n        if (len(shape) == 2 and shape[1] is not None and\n            shape[1] <= self._small_feature_dim_size):\n          tf.compat.v1.logging.log_first_n(\n              tf.compat.v1.logging.INFO,\n              'Found small feature: %s %s', 1, name, shape)\n          if tensor.dtype not in self._small_feature_names:\n            self._small_feature_names[dtype] = []\n            self._small_feature_sizes[dtype] = []\n          self._small_feature_names[dtype].append(name)\n          self._small_feature_sizes[dtype].append(shape[1])\n\n      dtypes_ = 
list(self._small_feature_names.keys())\n      for dtype in dtypes_:\n        # If we could find 5 (or more) [batch_size, 1] dense features,\n        # we will group them.\n        if (len(self._small_feature_names[dtype]) <\n            self._minimum_num_small_features_to_group):\n          self._small_feature_names.pop(dtype)  # reset\n          self._small_feature_sizes.pop(dtype)  # reset\n\n      # Second pass: separate small features out\n      small_feature_tensors = {}\n      for dtype in self._small_feature_names:\n        small_feature_tensors[dtype] = []\n        for name in self._small_feature_names[dtype]:\n          small_feature_tensors[dtype].append(features.pop(name))\n\n      # Add the concat Tensor to features with a special key.\n      for dtype in self._small_feature_names:\n        key = self._get_small_feature_key(dtype)\n        if key in features:\n          raise ValueError('{} is reserved as feature key for concatenated'\n                           'small features.')\n        features[key] = (tf.concat(small_feature_tensors[dtype], axis=1))\n\n  def maybe_split_features(self, maybe_concatenated_features):\n    for dtype in self._small_feature_names:\n      key = self._get_small_feature_key(dtype)\n      concatenated_small_features = maybe_concatenated_features.pop(key)\n      splits = tf.split(\n          concatenated_small_features, self._small_feature_sizes[dtype], axis=1)\n      for name, split in zip(self._small_feature_names[dtype], splits):\n        maybe_concatenated_features[name] = split\n\n  def _get_small_feature_key(self, dtype):\n    return _TENSOR_PACKER_CONCATENATED_SMALL_FEATURES_KEY + '_' + str(dtype)\n\n\nclass _InputPipeline(object):\n  \"\"\"`_InputPipeline` handles invoking `input_fn` and piping to infeed queue.\n\n  `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from\n  call site.  
To be precise, based on the configuration in\n  `_InternalTPUContext`,  it invokes `input_fn` for all cores (usually\n  multi-host TPU training) or for one host (usually for single-host TPU\n  evaluation), and sends all `features` and `labels` returned by `input_fn` to\n  TPU infeed. For per-core invocation, `features` and `labels` are piped to\n  infeed directly, one tuple for each core. For per-host invocation,  `features`\n  and `labels` are split at host (with respect to `batch_axis`) and piped to all\n  cores accordingly.\n\n  In addition, flatten/unflatten are handled by `_InputPipeline` also.  Model\n  inputs returned by the `input_fn` can have one of the following forms:\n  1. features\n  2. (features, labels)\n  3. ((arbitrarily nested structure of features), labels)\n\n  Internally, form 1 is reformed to `(features, None)` as features and labels\n  are passed separately to underlying methods. For TPU training, TPUEstimator\n  may expect multiple `features` and `labels` tuples one for each core.\n\n  TPUEstimator allows various different structures for inputs (namely `features`\n  and `labels`).  Both `features` and `labels` can be any nested sturcture\n  supported by TF nest (namely, dict, tuples, namedtuples or any nested\n  structure of such of Tensors).  
`labels` could be `None` as well.\n\n  These are flattened before they are passed to the infeed/outfeed library\n  as that expectes flattend lists.\n  \"\"\"\n\n  class InputsStructureRecorder(object):\n    \"\"\"The recorder to record inputs structure.\"\"\"\n\n    def __init__(self, input_partition_dims=None):\n      # Holds the structure of inputs\n      self._feature_structure = {}\n      self._flattened_input_dims = None\n\n      if input_partition_dims:\n        # This should have been validated in TPUConfig.\n        assert len(input_partition_dims) <= 2, 'must have 1 or 2 elements.'\n        if len(input_partition_dims) == 2:\n          self._feature_dims, self._label_dims = input_partition_dims\n        else:\n          self._feature_dims = input_partition_dims[0]\n          self._label_dims = None\n\n        assert self._feature_dims is not None, ('input_partition_dims[0] must '\n                                                'not be None')\n      else:\n        self._feature_dims = None\n        self._label_dims = None\n\n      # Internal state.\n      self._initialized = False\n\n    @property\n    def flattened_input_dims(self):\n      assert self._initialized, 'InputsStructureRecorder is not initialized.'\n      return self._flattened_input_dims\n\n    def has_labels(self):\n      return 'labels' in self._feature_structure\n\n    def _flatten_input_dims(self, features, labels, feature_dims, label_dims):\n      \"\"\"Flatten input dims with the same order as flattened input tensors.\"\"\"\n\n      try:\n        flattened_input_dims = data_nest.flatten_up_to(features, feature_dims)\n      except TypeError as e:\n        raise ValueError(\n            'TPUConfig.input_partition_dims[0] mismatched the structure of'\n            ' features. input_partition_dims[0]: {}, features {}. 
{}'.format(\n                feature_dims, features, e))\n\n      if labels is not None:\n        if label_dims is not None:\n          try:\n            flattened_input_dims.extend(\n                data_nest.flatten_up_to(labels, self._label_dims))\n          except TypeError as e:\n            raise ValueError(\n                'TPUConfig.input_partition_dims[1] mismatched the structure of'\n                ' labels. input_partition_dims[1]: {}, labels: {}. {}'.format(\n                    label_dims, labels, e))\n        else:\n          num_label_tensors = len(data_nest.flatten(labels))\n          flattened_input_dims.extend([None] * num_label_tensors)\n      return flattened_input_dims\n\n    def validate_and_record_structure(self, features, labels):\n      \"\"\"Validates and records the structure of `features` and `labels`.\"\"\"\n      # Extract structure.\n      feature_names = _extract_key_names(features)\n      label_names = _extract_key_names(labels)\n\n      if not self._initialized:\n        # Record structure.\n        self._initialized = True\n        if self._feature_dims is not None:\n          feature_dims_names = _extract_key_names(self._feature_dims)\n          if feature_dims_names != feature_names:\n            raise ValueError(\n                'TPUConfig.input_partition_dims[0] mismatched feature'\n                ' keys. Expected {}, got {}'.format(feature_names,\n                                                    feature_dims_names))\n          label_dims_names = _extract_key_names(self._label_dims)\n          if self._label_dims is not None and label_dims_names != label_names:\n            raise ValueError(\n                'TPUConfig.input_partition_dims[1] mismatched label'\n                ' keys. 
Expected {}, got {}'.format(label_names,\n                                                    label_dims_names))\n          self._flattened_input_dims = self._flatten_input_dims(\n              features, labels, self._feature_dims, self._label_dims)\n\n    def flatten_features_and_labels(self, features, labels, signals=None):\n      \"\"\"Flattens the `features` and `labels` to a single tensor list.\"\"\"\n      self.tensor_packer = TensorPacker(\n          _TENSOR_PACKER_SMALL_FEATURE_DIM_SIZE,\n          _TENSOR_PACKER_MINIMUM_NUM_SMALL_FEATURES_TO_GROUP)\n      self.tensor_packer.maybe_concatenate_features(features)\n      self._feature_structure['features'] = features\n      if labels is not None:\n        self._feature_structure['labels'] = labels\n      if signals is not None:\n        self._feature_structure['signals'] = signals\n      return data_nest.flatten(self._feature_structure)\n\n    def unflatten_features_and_labels(self, flattened_inputs):\n      \"\"\"Restores the flattened inputs to original features and labels form.\n\n      Args:\n        flattened_inputs: Flattened inputs for each shard.\n\n      Returns:\n        A tuple of (`features`, `labels`), where `labels` could be None.\n        Each one, if present, should have identical structure (single tensor vs\n        dict) as the one returned by input_fn.\n\n      Raises:\n        ValueError: If the number of expected tensors from `flattened_inputs`\n          mismatches the recorded structure.\n      \"\"\"\n\n      unflattened_inputs = data_nest.pack_sequence_as(self._feature_structure,\n                                                      flattened_inputs)\n      features = unflattened_inputs['features']\n      self.tensor_packer.maybe_split_features(features)\n      return _Inputs(\n          features,\n          unflattened_inputs.get('labels'),\n          signals=unflattened_inputs.get('signals'))\n\n  def __init__(self, input_fn, batch_axis, ctx):\n    \"\"\"Constructor.\n\n    Args:\n  
    input_fn: input fn for train or eval.\n      batch_axis: A python tuple of int values describing how each tensor\n        produced by the Estimator `input_fn` should be split across the TPU\n        compute shards.\n      ctx: A `_InternalTPUContext` instance with mode.\n\n    Raises:\n      ValueError: If both `sharded_features` and `num_cores` are `None`.\n    \"\"\"\n    self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder(\n        ctx.input_partition_dims)\n\n    self._sharded_per_core = ctx.is_input_sharded_per_core()\n    self._input_fn = input_fn\n    self._infeed_queue = None\n    self._ctx = ctx\n    self._batch_axis = batch_axis\n\n  def generate_infeed_enqueue_ops_and_dequeue_fn(self):\n    \"\"\"Generates infeed enqueue ops and dequeue_fn.\"\"\"\n    # While tf.while_loop is called, the body function, which invokes\n    # `enqueue_fn` passed in, is called to construct the graph. So, input_fn\n    # structure is recorded.\n    enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = (\n        self._invoke_input_fn_and_record_structure())\n\n    self._validate_input_pipeline()\n\n    def dequeue_fn():\n      \"\"\"dequeue_fn is used by TPU to retrieve the tensors.\"\"\"\n      # In the model-parallel case, both the host-side and device-side\n      # computations must agree on the core on which infeed takes place. 
We\n      # choose to perform infeed on logical core 0 of each replica.\n      values = self._infeed_queue.generate_dequeue_op(tpu_device=0)\n      # The unflatten process uses the structure information recorded above.\n      return self._inputs_structure_recorder.unflatten_features_and_labels(\n          values)\n\n    return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator)\n\n  def _invoke_input_fn_and_record_structure(self):\n    \"\"\"Deploys the input pipeline and record input structure.\"\"\"\n    enqueue_ops = []\n    infeed_queues = []\n    all_dataset_initializers = []\n    num_hosts = self._ctx.num_hosts\n    tpu_host_placement_fn = self._ctx.tpu_host_placement_function\n\n    run_infeed_loop_on_coordinator = True\n\n    if self._sharded_per_core:\n      # Per-Core input pipeline deployment.\n      # Invoke input pipeline for each core and placed on the corresponding\n      # host.\n      for host_id in range(num_hosts):\n        host_device = tpu_host_placement_fn(host_id=host_id)\n        with tf.compat.v1.device(host_device):\n          with ops.name_scope('input_pipeline_task%d' % (host_id)):\n            enqueue_ops_fn, captured_infeed_queue = (\n                generate_per_core_enqueue_ops_fn_for_host(\n                    self._ctx, self._input_fn, self._inputs_structure_recorder,\n                    host_device, host_id))\n\n            if _WRAP_INPUT_FN_INTO_WHILE_LOOP:\n              run_infeed_loop_on_coordinator = False\n              enqueue_ops.append(\n                  _wrap_computation_in_while_loop(\n                      device=host_device, op_fn=enqueue_ops_fn))\n            else:\n              enqueue_ops.append(enqueue_ops_fn())\n            # Infeed_queue_getter must be called after enqueue_ops_fn is called.\n            infeed_queues.append(captured_infeed_queue.get())\n\n    elif self._ctx.is_input_broadcast_with_iterators():\n      # Only calls input_fn in host 0.\n      host_device = 
tpu_host_placement_fn(host_id=0)\n      enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (\n          generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn,\n                                            self._inputs_structure_recorder,\n                                            num_hosts))\n      if dataset_initializer:\n        all_dataset_initializers.append(dataset_initializer)\n        run_infeed_loop_on_coordinator = False\n        wrap_fn = (\n            _wrap_computation_in_while_loop\n            if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else\n            _wrap_computation_in_while_loop_with_stopping_signals)\n        enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn))\n      else:\n        enqueue_ops.append(enqueue_ops_fn())\n      infeed_queues.append(captured_infeed_queue.get())\n\n    else:\n      # This branch handles two senarios:\n      #       num_cores_per_replica > num_cores_per_host\n      #   and num_cores_per_replica <= num_cores_per_host\n      # First, get the set of host_ids, by iterating replicas.\n      # We only want and will get the set of *unique* host_ids\n      # *that will call input_fn*. 
For each replica, we only call the input_fn\n      # from the CPU host that contains logical core 0.\n\n      # Use a list here to ensure deterministic order.\n      host_id_with_invocation_id_pair = []\n\n      if not self._ctx.is_replica_across_hosts():\n        for host_id in range(num_hosts):\n          invocation_index = host_id\n          host_id_with_invocation_id_pair.append((host_id, invocation_index))\n      else:\n        for replica_id in xrange(self._ctx.num_replicas):\n          invocation_index = replica_id\n          host_device, _ = self._ctx.device_for_replica(replica_id)\n          # TODO(lehou): Get host_id in a better way.\n          host_id = int(host_device.split('/task:')[1].split('/device:')[0])\n          host_id_with_invocation_id_pair.append((host_id, invocation_index))\n\n      for (host_id, invocation_index) in host_id_with_invocation_id_pair:\n        host_device = tpu_host_placement_fn(host_id=host_id)\n        with tf.compat.v1.device(host_device):\n          with ops.name_scope('input_pipeline_task%d' % (host_id)):\n            if self._ctx.is_input_per_host_with_iterators():\n              enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (\n                  generate_per_host_v2_enqueue_ops_fn_for_host(\n                      self._ctx, self._input_fn,\n                      self._inputs_structure_recorder, host_device, host_id,\n                      invocation_index))\n            else:\n              enqueue_ops_fn, captured_infeed_queue, dataset_initializer = (\n                  generate_per_host_enqueue_ops_fn_for_host(\n                      self._ctx, self._input_fn,\n                      self._inputs_structure_recorder, self._batch_axis,\n                      host_device, host_id))\n\n            # NOTE(xiejw): We dispatch here based on the return type of the\n            # users `input_fn`.\n            #\n            # 1. 
If input_fn returns a Dataset instance, we initialize the\n            # iterator outside of tf.while_loop, and call the iterator.get_next\n            # inside tf.while_loop.  This should be always safe.\n            #\n            # 2. If input_fn returns (features, labels), it is too late to wrap\n            # them inside tf.while_loop, as resource initialization cannot be\n            # handled in TF control flow properly. In this case, we will use\n            # python loop to enqueue the data into TPU system.  This may be\n            # slow compared to the previous case.\n            if dataset_initializer:\n              all_dataset_initializers.append(dataset_initializer)\n              run_infeed_loop_on_coordinator = False\n              wrap_fn = (\n                  _wrap_computation_in_while_loop\n                  if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else\n                  _wrap_computation_in_while_loop_with_stopping_signals)\n              enqueue_ops.append(\n                  wrap_fn(device=host_device, op_fn=enqueue_ops_fn))\n            else:\n              enqueue_ops.append(enqueue_ops_fn())\n            infeed_queues.append(captured_infeed_queue.get())\n\n    # infeed_queue is used to generate dequeue ops. The only thing it uses for\n    # dequeue is dtypes and types. So, any one can be used. Here, grab the\n    # first one.\n    self._infeed_queue = infeed_queues[0]\n    return enqueue_ops, [\n        util_lib.MultiHostDatasetInitializerHook(all_dataset_initializers)\n    ], run_infeed_loop_on_coordinator\n\n  def _validate_input_pipeline(self):\n    \"\"\"Validates the input pipeline.\n\n    Perform some sanity checks to log user friendly information. We should\n    error out to give users better error message. 
But, if\n    _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break\n    user code, so, log a warning.\n\n    Raises:\n      RuntimeError: If the validation failed.\n    \"\"\"\n    if tf.compat.v1.get_default_graph().get_collection(\n        tf.compat.v1.GraphKeys.QUEUE_RUNNERS):\n      err_msg = ('Input pipeline contains one or more QueueRunners. '\n                 'It could be slow and not scalable. Please consider '\n                 'converting your input pipeline to use `tf.data` instead (see '\n                 'https://www.tensorflow.org/guide/datasets for '\n                 'instructions.')\n      if _WRAP_INPUT_FN_INTO_WHILE_LOOP:\n        raise RuntimeError(err_msg)\n      else:\n        logging.warn(err_msg)\n\n\ndef call_computation(computation_inputs, computation, batch_config=None):\n  \"\"\"Call computation.\n\n  Args:\n    computation_inputs: A tensor or dict of tensors, the inputs to the\n      computation.\n    computation: A Python function that takes no inputs and builds computation\n      graph. If `computation` returns m outputs, this function will return a\n      list of m Tensors.\n    batch_config: A BatchConfig named tuple specifying the batching\n      configuration to use for inference batching.\n\n  Returns:\n    A list of output tensors.\n  \"\"\"\n\n  # Using `TPUPartitionedCall` makes it possible to target a different\n  # TPU core with every `Session.run()` call. 
Note that the entire inference\n  # graph executes on a single core, and that invocations of this graph\n  # will round-robin among the cores attached to a host.\n  def tpu_partitioned_call(partition_inputs):\n\n    # capture_resource_var_by_value enables variables to be mirrored on TPU\n    # to avoid fetching from CPU, since variables do not change during\n    # inference.\n    @function.Defun(capture_resource_var_by_value=False)\n    def tpu_subgraph():\n      return computation(partition_inputs)\n\n    return tpu_functional.TPUPartitionedCall(\n        args=tpu_subgraph.captured_inputs,\n        device_ordinal=tpu_ops.tpu_ordinal_selector(),\n        Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg],\n        f=tpu_subgraph)\n\n  # Not using Batching Function but use TPUPartitionedCall/all cores.\n  if not batch_config:\n    return tpu_partitioned_call(computation_inputs)\n\n  # Use Batching Function and TPUPartitionedCall/all cores.\n  # Note that BatchingFunction requires a list of tensors and doesn't support\n  # a dict of tensors. 
So we preserve the structure by deterministically\n  # flattening the dict before batching and then recomposing it after batching\n  # to feed into the computation.\n  ordered_inputs_list = tf.nest.flatten(computation_inputs)\n\n  @tf.nondifferentiable_batch_function(\n      num_batch_threads=batch_config.num_batch_threads,\n      max_batch_size=batch_config.max_batch_size,\n      batch_timeout_micros=batch_config.batch_timeout_micros,\n      allowed_batch_sizes=batch_config.allowed_batch_sizes,\n      max_enqueued_batches=batch_config.max_enqueued_batches,\n      autograph=False)\n  def batched_tpu_computation(*tensor_args):\n    \"\"\"Recompose the input feature dict and calls the TPU computation.\"\"\"\n    computation_feature_input = tf.nest.pack_sequence_as(\n        computation_inputs, tensor_args)\n    return tpu_partitioned_call(computation_feature_input)\n\n  return batched_tpu_computation(*ordered_inputs_list)\n\n\nclass _ModelFnWrapper(object):\n  \"\"\"A `model_fn` wrapper.\n\n  This makes calling model_fn on CPU and TPU easier and more consistent and\n  performs necessary check and mutation required by TPU training and evaluation.\n\n  In addition, this wrapper manages converting the `model_fn` to a single TPU\n  train and eval step.\n  \"\"\"\n\n  def __init__(self, model_fn, config, params, ctx):\n    self._model_fn = model_fn\n    self._config = config\n    self._params = params\n    self._ctx = ctx\n\n  def call_without_tpu(self, features, labels, is_export_mode):\n    return self._call_model_fn(features, labels, is_export_mode=is_export_mode)\n\n  def _add_embedding_features(self, features, hook_dummy_table_variables):\n    \"\"\"Add embedding features, optionally add hook to intercept gradient.\"\"\"\n    if self._ctx.embedding_config:\n      tpu_embedding_ = self._ctx.embedding_config.tpu_embedding\n      embedding_activations = tpu_embedding_.get_activations()\n      if hook_dummy_table_variables:\n        new_embedding_activations = (\n        
    tpu_embedding_gradient.hook_dummy_table_variables_to_activations(\n                tpu_embedding_, embedding_activations,\n                self._ctx.embedding_config.dummy_table_variables))\n        features.update(new_embedding_activations)\n      else:\n        features.update(embedding_activations)\n\n  def convert_to_single_tpu_train_step(self, dequeue_fn):\n    \"\"\"Converts user provided model_fn` as a single train step on TPU.\n\n    The user provided `model_fn` takes input tuple\n    (features, labels) and produces the EstimatorSpec with train_op and loss for\n    train `mode`. This usually represents a single train computation on CPU.\n\n    For TPU training, a train (computation) step is first wrapped in a\n    tf.while_loop control flow to repeat for many times and then replicated to\n    all TPU shards. Besides the input should be taken from TPU infeed rather\n    than input pipeline (input_fn) directly. To fit TPU loop and replicate\n    pattern, the original train computation should be reformed, which is the\n    returned `train_step`.\n\n    Args:\n      dequeue_fn: The function to retrieve inputs, features and labels, from TPU\n        infeed dequeue channel.\n\n    Returns:\n      A tuple of train_fn, host_calls, and captured scaffold_fn. 
The train_fn\n      representing the train step for TPU.\n    \"\"\"\n\n    host_call = _OutfeedHostCall(\n        self._ctx,\n        outfeed_every_n_steps=self._config.tpu_config\n        .experimental_host_call_every_n_steps)\n    captured_scaffold_fn = _CapturedObject()\n    captured_training_hooks = _CapturedObject()\n\n    def train_step(step):\n      \"\"\"Training step function for use inside a while loop.\"\"\"\n      inputs = dequeue_fn()\n      features, labels = inputs.features_and_labels()\n      self._add_embedding_features(features, True)\n\n      estimator_spec = self._verify_estimator_spec(\n          self._call_model_fn(features, labels))\n      loss, train_op = estimator_spec.loss, estimator_spec.train_op\n\n      if tensor_tracer.TensorTracer.is_enabled():\n        tt = tensor_tracer.TensorTracer()\n        loss = tt.trace_tpu(tf.compat.v1.get_default_graph(), loss, train_op,\n                            self._ctx.num_replicas)\n        tracer_host_call = tt.host_call_deps_and_fn()\n      else:\n        tracer_host_call = {}\n\n      if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access\n        captured_scaffold_fn.capture(estimator_spec.scaffold_fn)\n      else:\n        captured_scaffold_fn.capture(None)\n\n      captured_training_hooks.capture(estimator_spec.training_hooks)\n\n      if self._ctx.embedding_config is None:\n        apply_sparse_grads = []\n      else:\n        tpu_embedding_ = self._ctx.embedding_config.tpu_embedding\n        gradients = (\n            tpu_embedding_gradient.get_gradients_through_dummy_table_variables(\n                tpu_embedding_))\n        grad_multiplier = self._ctx.embedding_config.get_grad_multiplier()\n        if grad_multiplier is not None:\n          scaled_gradients = collections.OrderedDict(\n              (k, v * grad_multiplier) for k, v in six.iteritems(gradients))\n        else:\n          scaled_gradients = gradients\n        apply_sparse_grads = 
[\n            tpu_embedding_.generate_send_gradients_op(\n                scaled_gradients, tf.compat.v1.train.get_global_step())\n        ]\n\n      stopping_signals = None\n      user_provided_stopping_signals_name = None\n      if self._ctx.feed_hook is not None:\n        stopping_signals, user_provided_stopping_signals_name = \\\n          self._ctx.feed_hook.get_stopping_signals_and_name(features)\n\n      # We must run train_op to update the variables prior to running the\n      # outfeed.\n      with tf.control_dependencies([train_op] + apply_sparse_grads):\n        host_call_outfeed_ops = []\n        host_call_fn, host_call_args = None, []\n\n        if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)  # pylint: disable=protected-access\n            and estimator_spec.host_call is not None):\n          host_call_fn, host_call_args = estimator_spec.host_call\n\n        if stopping_signals is not None:\n          identity_fn = lambda **kwargs: kwargs\n          tracer_host_call[user_provided_stopping_signals_name] = [\n              identity_fn, stopping_signals\n          ]\n\n        if host_call_fn:\n          # Ignore dummy hostcalls (no arguments)\n          if host_call_args:\n            tracer_host_call.update({'host_call': estimator_spec.host_call})\n            host_call.record(tracer_host_call)\n            host_call_outfeed_ops = host_call.create_enqueue_op(step)\n          elif tracer_host_call:\n            host_call.record(tracer_host_call)\n            host_call_outfeed_ops = host_call.create_enqueue_op(step)\n        else:\n          # Create a host call for the loss to track execution progress\n          # Without this, we don't have any indication of the state of the\n          # TPU program.\n          tracer_host_call.update(\n              {'host_call': (lambda loss_t: loss_t, [tf.reshape(loss, [1])])})\n          host_call.record(tracer_host_call)\n          host_call_outfeed_ops = host_call.create_enqueue_op(step)\n\n       
 with tf.control_dependencies(host_call_outfeed_ops):\n          return tf.identity(loss)\n\n    return (train_step, host_call, captured_scaffold_fn,\n            captured_training_hooks)\n\n  def convert_to_single_tpu_eval_step(self, dequeue_fn):\n    \"\"\"Converts user provided model_fn` as a single eval step on TPU.\n\n    Similar to training, the user provided `model_fn` takes input tuple\n    (features, labels) and produces the TPUEstimatorSpec with eval_metrics for\n    eval `mode`. This usually represents a single evaluation computation on CPU.\n\n    For TPU evaluation, a eval (computation) step is first wrapped in a\n    tf.while_loop control flow to repeat for many times and then replicated to\n    all TPU shards. Besides the input and output are slightly different. Input,\n    features and labels, should be taken from TPU infeed rather than input\n    pipeline (input_fn) directly. Output is managed in two stages.  First, the\n    model outputs as the result of evaluation computation, usually model logits,\n    should be transferred from TPU system to CPU. Then, all model outputs are\n    concatenated first on CPU and sent to the metric_fn for metrics computation.\n    To fit TPU evaluation pattern, the original eval computation should be\n    reformed, which is the returned `eval_step`.\n\n    Args:\n      dequeue_fn: The function to retrieve inputs, features and labels, from TPU\n        infeed dequeue channel.\n\n    Returns:\n      A tuple of eval_fn, host_calls, and captured scaffold_fn. 
The eval_fn\n      representing the eval step for TPU.\n    \"\"\"\n    host_calls = _OutfeedHostCall(self._ctx)\n    captured_scaffold_fn = _CapturedObject()\n    captured_eval_hooks = _CapturedObject()\n\n    def eval_step(total_loss):\n      \"\"\"Evaluation step function for use inside a while loop.\"\"\"\n      inputs = dequeue_fn()\n      features, labels = inputs.features_and_labels()\n      self._add_embedding_features(features, False)\n\n      tpu_estimator_spec = self._call_model_fn(features, labels)\n      if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access\n        raise RuntimeError(\n            'estimator_spec used by TPU evaluation must have type'\n            '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))\n\n      loss = tpu_estimator_spec.loss\n      captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)\n      captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks)\n\n      to_record = {}\n      if tpu_estimator_spec.eval_metrics:\n        to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics\n      if tpu_estimator_spec.host_call is not None:\n        # We assume that evaluate won't update global step, so we don't wrap\n        # this host_call.\n        to_record['host_call'] = tpu_estimator_spec.host_call\n      host_calls.record(to_record)\n\n      with tf.control_dependencies(host_calls.create_enqueue_op()):\n        return tf.math.add(total_loss, loss)\n\n    return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks\n\n  def convert_to_single_tpu_predict_step(self, dequeue_fn):\n    \"\"\"Converts user provided model_fn` as a single predict step on TPU.\n\n    Args:\n      dequeue_fn: The function to retrieve inputs, features and labels, from TPU\n        infeed dequeue channel.\n\n    Returns:\n      A tuple of predict_fn, host_calls, and captured scaffold_fn. 
The\n      predict_fn representing the predict step for TPU.\n    \"\"\"\n    host_calls = _OutfeedHostCall(self._ctx)\n    captured_scaffold_fn = _CapturedObject()\n    captured_predict_hooks = _CapturedObject()\n\n    def predict_step(unused_scalar_stopping_signal):\n      \"\"\"Evaluation step function for use inside a while loop.\"\"\"\n      inputs = dequeue_fn()\n      features, labels = inputs.features_and_labels()\n      stopping_signals = inputs.signals()\n\n      assert stopping_signals is not None, (\n          'Internal Error: `signals` is missing.')\n\n      tpu_estimator_spec = self._call_model_fn(\n          features, labels, is_export_mode=False)\n      if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access\n        raise RuntimeError(\n            'estimator_spec used by TPU prediction must have type'\n            '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec)))\n\n      self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions)\n\n      captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn)\n      captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks)\n      to_record = {}\n      identity_fn = lambda **kwargs: kwargs\n      to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions]\n      to_record['signals'] = [identity_fn, stopping_signals]\n      if tpu_estimator_spec.host_call is not None:\n        to_record['host_call'] = tpu_estimator_spec.host_call\n      host_calls.record(to_record)\n\n      with tf.control_dependencies(host_calls.create_enqueue_op()):\n        return _StopSignals.as_scalar_stopping_signal(stopping_signals)\n\n    return (predict_step, host_calls, captured_scaffold_fn,\n            captured_predict_hooks)\n\n  def _verify_tpu_spec_predictions(self, predictions):\n    \"\"\"Validates TPUEstimatorSpec.predictions dict.\"\"\"\n    # TODO(xiejw): Adds validation for prediction dictionrary.\n    # TODO(xiejw): Adds support 
for single tensor as predictions.\n    if not isinstance(predictions, dict):\n      raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.')\n\n    for (key, tensor) in predictions.items():\n      if tensor.shape.dims[0].value is None:\n        raise ValueError(\n            'The tensor with key ({}) in TPUEstimatorSpec.predictions has '\n            'dynamic shape (should be static). Tensor: {}'.format(key, tensor))\n    return predictions\n\n  def _validate_model_features_and_labels(self, features, labels,\n                                          is_export_mode):\n    \"\"\"Validates that the features and labels for the model function are valid.\n\n    A valid features/labels object is the one with:\n    - Type: A tensor or any nested structure of tensors supported by TF nest,\n        namely nested dictionary, tuple, namedtuple, or sequence of tensors.\n    - Static shape if is_export_mode is False.\n\n    Args:\n      features: the features that would be input to the model function.\n      labels: the labels that would be input to the model function.\n      is_export_mode: boolean value specifying if in export mode.\n\n    Raises:\n      TypeError: If features/labels are not of the correct type.\n      ValueError: If features/labels have dynamic shape.\n    \"\"\"\n\n    def validate(obj, obj_name):\n      \"\"\"Helper validate function.\"\"\"\n      if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode):\n        return\n      if isinstance(obj, tf.Tensor):\n        if not obj.get_shape().is_fully_defined():\n          raise ValueError(\n              'The {} to the model returned by input_fn must have static shape.'\n              ' Tensor: {}'.format(obj_name, obj))\n      else:\n        for tensor in data_nest.flatten(obj):\n          if not tensor.get_shape().is_fully_defined():\n            raise ValueError(\n                ('The {} to the model returned by input_fn must have static '\n                 'shape. 
Tensor: {}').format(obj_name, tensor))\n\n    validate(features, 'features')\n    if labels is not None:\n      validate(labels, 'labels')\n\n  def _call_model_fn(self, features, labels, is_export_mode=False):\n    \"\"\"Calls the model_fn with required parameters.\"\"\"\n    self._validate_model_features_and_labels(features, labels, is_export_mode)\n    model_fn_args = function_utils.fn_args(self._model_fn)\n    kwargs = {}\n\n    # Makes deep copy with `config` and params` in case user mutates them.\n    config = copy.deepcopy(self._config)\n    params = copy.deepcopy(self._params)\n\n    if 'labels' in model_fn_args:\n      kwargs['labels'] = labels\n    elif labels is not None:\n      raise ValueError(\n          'model_fn does not take labels, but input_fn returns labels.')\n    if 'mode' in model_fn_args:\n      kwargs['mode'] = self._ctx.mode\n    if 'config' in model_fn_args:\n      kwargs['config'] = config\n    if 'params' in model_fn_args:\n      kwargs['params'] = params\n\n    if 'params' not in model_fn_args:\n      raise ValueError('model_fn ({}) does not include params argument, '\n                       'required by TPUEstimator to pass batch size as '\n                       'params[\\'batch_size\\']'.format(self._model_fn))\n\n    if is_export_mode:\n      batch_size_for_model_fn = None\n    else:\n      batch_size_for_model_fn = self._ctx.batch_size_for_model_fn\n\n    if batch_size_for_model_fn is not None:\n      _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn)\n\n    running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode)\n    # In export mode, params['use_tpu'] has already been set based on mode\n    # (i.e. 
True for _REWRITE_FOR_INFERENCE_MODE, False otherwise).\n    if not is_export_mode:\n      _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu)\n\n    if not running_on_cpu:\n      user_context = tpu_context.TPUContext(\n          internal_ctx=self._ctx, call_from_input_fn=False)\n      _add_item_to_params(params, _CTX_KEY, user_context)\n\n    estimator_spec = self._model_fn(features=features, **kwargs)\n    if (running_on_cpu and\n        isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)):  # pylint: disable=protected-access\n      # The estimator_spec will be passed to `Estimator` directly, which expects\n      # type `EstimatorSpec`. As we are running on the CPU, escape\n      # the TPUInferenceContext.\n      graph_context = tf.compat.v1.get_default_graph(\n      )._get_control_flow_context()\n      try:\n        if isinstance(graph_context, tpu._TPUInferenceContext):\n          tf.compat.v1.get_default_graph()._set_control_flow_context(\n              graph_context.outer_context)\n        return estimator_spec.as_estimator_spec()\n      finally:\n        tf.compat.v1.get_default_graph()._set_control_flow_context(\n            graph_context)\n    else:\n      return estimator_spec\n\n  def _verify_estimator_spec(self, estimator_spec):\n    \"\"\"Validates the estimator_spec.\"\"\"\n    if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec):  # pylint: disable=protected-access\n      return estimator_spec\n\n    err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.'\n    if estimator_spec.training_chief_hooks:\n      raise ValueError(\n          err_msg.format('training_chief_hooks') + 'If you want' +\n          ' to pass training hooks, please pass via training_hooks.')\n\n    if estimator_spec.scaffold:\n      tf.compat.v1.logging.warn(\n          'EstimatorSpec.Scaffold is ignored by TPU train/eval. 
'\n          'Please use TPUEstimatorSpec.')\n    return estimator_spec\n\n\nclass _OutfeedHostCall(object):\n  \"\"\"Support for `eval_metrics` and `host_call` in TPUEstimatorSpec.\"\"\"\n\n  def __init__(self, ctx, outfeed_every_n_steps=1):\n    self._ctx = ctx\n    self._names = []\n    # All of these are dictionaries of lists keyed on the name.\n    self._host_fns = {}\n    self._tensor_keys = collections.defaultdict(list)\n    self._tensors = collections.defaultdict(list)\n    self._tensor_dtypes = collections.defaultdict(list)\n    self._tensor_shapes = collections.defaultdict(list)\n    self._outfeed_every_n_steps = outfeed_every_n_steps\n\n  @staticmethod\n  def validate(host_calls):\n    \"\"\"Validates the `eval_metrics` and `host_call` in `TPUEstimatorSpec`.\"\"\"\n\n    for name, host_call in host_calls.items():\n      if not isinstance(host_call, (tuple, list)):\n        raise ValueError('{} should be tuple or list'.format(name))\n      if len(host_call) != 2:\n        raise ValueError('{} should have two elements.'.format(name))\n      if not callable(host_call[0]):\n        raise TypeError('{}[0] should be callable.'.format(name))\n      if not isinstance(host_call[1], (tuple, list, dict)):\n        raise ValueError('{}[1] should be tuple or list, or dict.'.format(name))\n\n      if isinstance(host_call[1], (tuple, list)):\n        fullargspec = tf_inspect.getfullargspec(host_call[0])\n        fn_args = function_utils.fn_args(host_call[0])\n        # wrapped_hostcall_with_global_step uses varargs, so we allow that.\n        if fullargspec.varargs is None and len(host_call[1]) != len(fn_args):\n          raise RuntimeError(\n              'In TPUEstimatorSpec.{}, length of tensors {} does not match '\n              'method args of the function, which takes {}.'.format(\n                  name, len(host_call[1]), len(fn_args)))\n\n  @staticmethod\n  def create_cpu_hostcall(host_calls):\n    \"\"\"Runs on the host_call on CPU instead of TPU when 
use_tpu=False.\"\"\"\n\n    _OutfeedHostCall.validate(host_calls)\n    ret = {}\n    for name, host_call in host_calls.items():\n      host_fn, tensors = host_call\n      if isinstance(tensors, (tuple, list)):\n        ret[name] = host_fn(*tensors)\n      else:\n        # Must be dict.\n        try:\n          ret[name] = host_fn(**tensors)\n        except TypeError as e:\n          tf.compat.v1.logging.warn(\n              'Exception while calling %s: %s. It is likely the tensors '\n              '(%s[1]) do not match the '\n              'function\\'s arguments', name, e, name)\n          raise\n    return ret\n\n  def record(self, host_calls):\n    \"\"\"Records the host_call structure.\"\"\"\n\n    for name, host_call in host_calls.items():\n      host_fn, tensor_list_or_dict = host_call\n      self._names.append(name)\n      self._host_fns[name] = host_fn\n\n      if isinstance(tensor_list_or_dict, dict):\n        for (key, tensor) in six.iteritems(tensor_list_or_dict):\n          self._tensor_keys[name].append(key)\n          self._tensors[name].append(tensor)\n          self._tensor_dtypes[name].append(tensor.dtype)\n          self._tensor_shapes[name].append(tensor.shape)\n      else:\n        # List or tuple.\n        self._tensor_keys[name] = None\n        for tensor in tensor_list_or_dict:\n          self._tensors[name].append(tensor)\n          self._tensor_dtypes[name].append(tensor.dtype)\n          self._tensor_shapes[name].append(tensor.shape)\n\n  def create_enqueue_op(self, step=None):\n    \"\"\"Create the op to enqueue the recorded host_calls.\n\n    Returns:\n      A list of enqueue ops, which is empty if there are no host calls.\n    \"\"\"\n    if not self._names:\n      return []\n\n    tensors = []\n    # TODO(jhseu): Consider deduping tensors.\n    for name in self._names:\n      tensors.extend(self._tensors[name])\n\n    if self._outfeed_every_n_steps > 1 and step is None:\n      raise ValueError('If outfeed is requested every n steps, 
you must pass '\n                       'a tensor whose value is the step number within the '\n                       'current training loop.')\n    with tf.compat.v1.device(tf.compat.v1.tpu.core(0)):\n      if self._outfeed_every_n_steps == 1:\n        return [tpu_ops.outfeed_enqueue_tuple(tensors)]\n      else:\n        return [\n            tf.compat.v1.cond(\n                tf.math.equal(\n                    tf.math.floormod(step, self._outfeed_every_n_steps),\n                    0), lambda: tpu_ops.outfeed_enqueue_tuple(tensors),\n                lambda: tf.no_op())\n        ]\n\n  def create_tpu_hostcall(self):\n    \"\"\"Sends the tensors through outfeed and runs the host_fn on CPU.\n\n    The tensors are concatenated along dimension 0 to form a global tensor\n    across all shards. The concatenated function is passed to the host_fn and\n    executed on the first host.\n\n    Returns:\n      A dictionary mapping name to the return type of the host_call by that\n      name.\n\n    Raises:\n      RuntimeError: If outfeed tensor is scalar.\n    \"\"\"\n    if not self._names:\n      return {}\n\n    ret = {}\n    # For each i, dequeue_ops[i] is a list containing the tensors from all\n    # shards. This list is concatenated later.\n    dequeue_ops = []\n    tensor_dtypes = []\n    tensor_shapes = []\n    for name in self._names:\n      for _ in self._tensors[name]:\n        dequeue_ops.append([])\n      for dtype in self._tensor_dtypes[name]:\n        tensor_dtypes.append(dtype)\n      for shape in self._tensor_shapes[name]:\n        tensor_shapes.append(shape)\n\n    # Outfeed ops execute on each replica's first logical core. 
Note: we must\n    # constraint it such that we have at most one outfeed dequeue and enqueue\n    # per replica.\n    for i in xrange(self._ctx.num_replicas):\n      host_device, ordinal_id = self._ctx.device_for_replica(i)\n      with tf.compat.v1.device(host_device):\n        outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(\n            dtypes=tensor_dtypes,\n            shapes=tensor_shapes,\n            device_ordinal=ordinal_id)\n        for j, item in enumerate(outfeed_tensors):\n          dequeue_ops[j].append(item)\n\n    # Deconstruct dequeue ops.\n    flat_dequeue_ops = []\n    for l in dequeue_ops:\n      flat_dequeue_ops.extend(l)\n\n    dequeue_ops_by_name = {}\n    pos = 0\n    for name in self._names:\n      dequeue_ops_by_name[name] = dequeue_ops[pos:pos +\n                                              len(self._tensors[name])]\n      pos += len(self._tensors[name])\n\n    def _call_host_fn(fn, *args, **kw):\n      context = CatchInvalidHostcallFunctions()\n      context.Enter()\n      result = fn(*args, **kw)\n      context.Exit()\n      context.ExitResult(result)\n      return result\n\n    # It is assumed evaluation always happens on single host TPU system. So,\n    # place all ops on tpu host if possible.\n    #\n    # TODO(jhseu): Evaluate whether this is right for summaries.\n    with tf.compat.v1.device(\n        self._ctx.tpu_host_placement_function(replica_id=0)):\n      for name in self._names:\n        dequeue_ops = dequeue_ops_by_name[name]\n        for i, item in enumerate(dequeue_ops):\n          # TODO(xiejw): Make the specification of the outfeed combinaton\n          # function more explicit and well-documented.  
We may want to give the\n          # user the option of concatenating along any axis.\n          if (self._ctx.config.tpu_config.per_host_input_for_training is\n              tpu_config.InputPipelineConfig.BROADCAST):\n            # If the infeed is in BROADCAST mode (each core recieving the same\n            # input), then we assume that the cores also produce identical\n            # copies of the same output, and we simply take the output from\n            # the first core.  This mode is used by Mesh-TensorFlow.\n            with tf.control_dependencies(dequeue_ops[i]):\n              dequeue_ops[i] = tf.identity(dequeue_ops[i][0])\n          else:\n            if dequeue_ops[i][0].shape.ndims == 0:\n              raise RuntimeError(\n                  'All tensors outfed from TPU should preserve batch size '\n                  'dimension, but got scalar {}'.format(dequeue_ops[i][0]))\n            # Assume that the input has been batch-split and that axis 0 of the\n            # output tensors represents the batch size.  Concatenate along\n            # the axis 0 to re-combine the batch.\n            dequeue_ops[i] = tf.concat(dequeue_ops[i], axis=0)\n\n        if self._tensor_keys[name] is not None:\n          # The user-provided eval_metrics[1] is a dict.\n          dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops))\n          try:\n            ret[name] = _call_host_fn(self._host_fns[name], **dequeue_ops)\n          except TypeError as e:\n            tf.compat.v1.logging.warn(\n                'Exception while calling %s: %s. 
It is likely the tensors '\n                '(%s[1]) do not match the '\n                'function\\'s arguments', name, e, name)\n            raise\n        else:\n          ret[name] = _call_host_fn(self._host_fns[name], *dequeue_ops)\n\n    # force all dequeue operations to be run if not consumed by the host calls\n    ret['__force_dequeue'] = tf.group(*flat_dequeue_ops)\n    return ret\n\n\nclass _OutfeedHostCallHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Hook to run host calls when use_tpu=False.\"\"\"\n\n  def __init__(self, tensors):\n    self._tensors = tensors\n\n  def begin(self):\n    # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than\n    # create a separate hook to guarantee execution order, because summaries\n    # need to be initialized before the outfeed thread starts.\n    # TODO(jhseu): Make a wrapper hook instead?\n    self._init_ops = summary_ops_v2.summary_writer_initializer_op()\n    # Get all the writer resources from the initializer, so we know what to\n    # flush.\n    self._finalize_ops = []\n    for op in self._init_ops:\n      self._finalize_ops.append(\n        summary_ops_v2.legacy_raw_flush(writer=op.inputs[0]))\n\n  def after_create_session(self, session, coord):\n    session.run(self._init_ops)\n\n  def before_run(self, run_context):\n    return tf.compat.v1.train.SessionRunArgs(self._tensors)\n\n  def end(self, session):\n    session.run(self._finalize_ops)\n\n\nclass _NotSaver(object):\n  \"\"\"What to pass instead of a saver object if you don't want saving.\"\"\"\n\n  def __init__(self, message):\n    self._message = message\n\n  def save(self, *args, **kwargs):\n    del args, kwargs\n    tf.compat.v1.logging.info(self._message)\n\n\nclass ExamplesPerSecondHook(tf.compat.v1.train.StepCounterHook):\n  \"\"\"Calculate and report global_step/sec and examples/sec during runtime.\"\"\"\n\n  def __init__(self,\n               batch_size,\n               every_n_steps=100,\n               
every_n_secs=None,\n               output_dir=None,\n               summary_writer=None):\n    self._batch_size = batch_size\n    super(ExamplesPerSecondHook, self).__init__(\n        every_n_steps=every_n_steps,\n        every_n_secs=every_n_secs,\n        output_dir=output_dir,\n        summary_writer=summary_writer)\n\n  def _log_and_record(self, elapsed_steps, elapsed_time, global_step):\n    global_step_per_sec = elapsed_steps / elapsed_time\n    examples_per_sec = self._batch_size * global_step_per_sec\n    if self._summary_writer is not None:\n      global_step_summary = Summary(value=[\n          Summary.Value(\n              tag='global_step/sec', simple_value=global_step_per_sec)\n      ])\n      example_summary = Summary(value=[\n          Summary.Value(tag='examples/sec', simple_value=examples_per_sec)\n      ])\n      self._summary_writer.add_summary(global_step_summary, global_step)\n      self._summary_writer.add_summary(example_summary, global_step)\n    tf.compat.v1.logging.info('global_step/sec: %g', global_step_per_sec)\n    tf.compat.v1.logging.info('examples/sec: %g', examples_per_sec)\n\n\nclass InstallSignalHandlerHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Change SIGINT (CTRL^C) handler to force quit the process.\n\n  The default behavior often results in hanging processes.\n  The original handler is restored after training/evaluation.\n  \"\"\"\n\n  def __init__(self):\n    self._signal_fn = signal.getsignal(signal.SIGINT)\n\n  def before_run(self, run_context):\n    signal.signal(signal.SIGINT, signal.SIG_DFL)\n\n  def end(self, session):\n    signal.signal(signal.SIGINT, self._signal_fn)\n\n\nclass ExportSavedModelApiVersion(enum.Enum):\n  V1 = 1\n  V2 = 2\n\n\nclass BatchConfig(\n    collections.namedtuple('BatchConfig', [\n        'num_batch_threads', 'max_batch_size', 'batch_timeout_micros',\n        'allowed_batch_sizes', 'max_enqueued_batches'\n    ])):\n  \"\"\"Class to handle config inputs into the batching 
function.\"\"\"\n\n  def __new__(cls,\n              num_batch_threads,\n              max_batch_size,\n              batch_timeout_micros,\n              allowed_batch_sizes,\n              max_enqueued_batches=100):\n    \"\"\"Creates an BatchConfig instance.\n\n    Args:\n     num_batch_threads: Number of scheduling threads for processing batches of\n       work. Determines the number of batches processed in parallel.\n      max_batch_size: Batch sizes will never be bigger than this.\n      batch_timeout_micros: Maximum number of microseconds to wait before\n        outputting an incomplete batch.\n      allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,\n        does nothing. Otherwise, supplies a list of batch sizes, causing the op\n        to pad batches up to one of those sizes. The entries must increase\n        monotonically, and the final entry must equal max_batch_size.\n      max_enqueued_batches: The maximum depth of the batch queue. Defaults to\n        100.\n\n    Returns:\n      An BatchConfig instance.\n    \"\"\"\n    return super(BatchConfig, cls).__new__(\n        cls,\n        num_batch_threads=num_batch_threads,\n        max_batch_size=max_batch_size,\n        batch_timeout_micros=batch_timeout_micros,\n        allowed_batch_sizes=allowed_batch_sizes,\n        max_enqueued_batches=max_enqueued_batches)\n\n\n@estimator_export(v1=['estimator.tpu.TPUEstimator'])\nclass TPUEstimator(estimator_lib.Estimator):\n  \"\"\"Estimator with TPU support.\n\n  TPUEstimator also supports training on CPU and GPU. You don't need to define\n  a separate `tf.estimator.Estimator`.\n\n  TPUEstimator handles many of the details of running on TPU devices, such as\n  replicating inputs and models for each core, and returning to host\n  periodically to run hooks.\n\n  TPUEstimator transforms a global batch size in params to a per-shard batch\n  size when calling the `input_fn` and `model_fn`. 
Users should specify\n  global batch size in constructor, and then get the batch size for each shard\n  in `input_fn` and `model_fn` by `params['batch_size']`.\n\n  - For training, `model_fn` gets per-core batch size; `input_fn` may get\n    per-core or per-host batch size depending on `per_host_input_for_training`\n    in `TPUConfig` (See docstring for TPUConfig for details).\n\n  - For evaluation and prediction, `model_fn` gets per-core batch size and\n    `input_fn` get per-host batch size.\n\n  Evaluation\n  ==========\n\n  `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics`\n  for TPU evaluation. If eval_on_tpu is False, the evaluation will execute on\n  CPU or GPU; in this case the following discussion on TPU evaluation does not\n  apply.\n\n  `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where\n  `tensors` could be a list of any nested structure of `Tensor`s (See\n  `TPUEstimatorSpec` for details).  `metric_fn` takes the `tensors` and returns\n  a dict from metric string name to the result of calling a metric function,\n  namely a `(metric_tensor, update_op)` tuple.\n\n  One can set `use_tpu` to `False` for testing. All training, evaluation, and\n  predict will be executed on CPU. `input_fn` and `model_fn` will receive\n  `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`.\n\n  Current limitations:\n  --------------------\n\n  1. TPU evaluation only works on a single host (one TPU worker) except\n     BROADCAST mode.\n\n  2. `input_fn` for evaluation should **NOT** raise an end-of-input exception\n     (`OutOfRangeError` or `StopIteration`). 
And all evaluation steps and all\n     batches should have the same size.\n\n  Example (MNIST):\n  ----------------\n\n  ```\n  # The metric Fn which runs on CPU.\n  def metric_fn(labels, logits):\n    predictions = tf.argmax(logits, 1)\n    return {\n      'accuracy': tf.compat.v1.metrics.precision(\n          labels=labels, predictions=predictions),\n    }\n\n  # Your model Fn which runs on TPU (eval_metrics is list in this example)\n  def model_fn(features, labels, mode, config, params):\n    ...\n    logits = ...\n\n    if mode = tf.estimator.ModeKeys.EVAL:\n      return tpu_estimator.TPUEstimatorSpec(\n          mode=mode,\n          loss=loss,\n          eval_metrics=(metric_fn, [labels, logits]))\n\n  # or specify the eval_metrics tensors as dict.\n  def model_fn(features, labels, mode, config, params):\n    ...\n    final_layer_output = ...\n\n    if mode = tf.estimator.ModeKeys.EVAL:\n      return tpu_estimator.TPUEstimatorSpec(\n          mode=mode,\n          loss=loss,\n          eval_metrics=(metric_fn, {\n              'labels': labels,\n              'logits': final_layer_output,\n          }))\n  ```\n\n  Prediction\n  ==========\n\n  Prediction on TPU is an experimental feature to support large batch inference.\n  It is not designed for latency-critical system. In addition, due to some\n  usability issues, for prediction with small dataset, CPU `.predict`, i.e.,\n  creating a new `TPUEstimator` instance with `use_tpu=False`, might be more\n  convenient.\n\n  Note: In contrast to TPU training/evaluation, the `input_fn` for prediction\n  *should* raise an end-of-input exception (`OutOfRangeError` or\n  `StopIteration`), which serves as the stopping signal to `TPUEstimator`. To be\n  precise, the ops created by `input_fn` produce one batch of the data.\n  The `predict()` API processes one batch at a time. When reaching the end of\n  the data source, an end-of-input exception should be raised by one of these\n  operations. 
The user usually does not need to do this manually. As long as the\n  dataset is not repeated forever, the `tf.data` API will raise an end-of-input\n  exception automatically after the last batch has been produced.\n\n  Note: Estimator.predict returns a Python generator. Please consume all the\n  data from the generator so that TPUEstimator can shutdown the TPU system\n  properly for user.\n\n  Current limitations:\n  --------------------\n  1. TPU prediction only works on a single host (one TPU worker).\n\n  2. `input_fn` must return a `Dataset` instance rather than `features`. In\n  fact, .train() and .evaluate() also support Dataset as return value.\n\n  Example (MNIST):\n  ----------------\n  ```\n  height = 32\n  width = 32\n  total_examples = 100\n\n  def predict_input_fn(params):\n    batch_size = params['batch_size']\n\n    images = tf.random.uniform(\n        [total_examples, height, width, 3], minval=-1, maxval=1)\n\n    dataset = tf.data.Dataset.from_tensor_slices(images)\n    dataset = dataset.map(lambda images: {'image': images})\n\n    dataset = dataset.batch(batch_size)\n    return dataset\n\n  def model_fn(features, labels, params, mode):\n     # Generate predictions, called 'output', from features['image']\n\n    if mode == tf.estimator.ModeKeys.PREDICT:\n      return tf.contrib.tpu.TPUEstimatorSpec(\n          mode=mode,\n          predictions={\n              'predictions': output,\n              'is_padding': features['is_padding']\n          })\n\n  tpu_est = TPUEstimator(\n      model_fn=model_fn,\n      ...,\n      predict_batch_size=16)\n\n  # Fully consume the generator so that TPUEstimator can shutdown the TPU\n  # system.\n  for item in tpu_est.predict(input_fn=input_fn):\n    # Filter out item if the `is_padding` is 1.\n    # Process the 'predictions'\n  ```\n\n  Exporting\n  =========\n\n  `export_saved_model` exports 2 metagraphs, one with `saved_model.SERVING`, and\n  another with `saved_model.SERVING` and `saved_model.TPU` tags. 
At serving\n  time, these tags are used to select the appropriate metagraph to load.\n\n  Before running the graph on TPU, the TPU system needs to be initialized. If\n  TensorFlow Serving model-server is used, this is done automatically. If not,\n  please use `session.run(tpu.initialize_system())`.\n\n  There are two versions of the API: 1 or 2.\n\n  In V1, the exported CPU graph is `model_fn` as it is. The exported TPU graph\n  wraps `tpu.rewrite()` and `TPUPartitionedCallOp` around `model_fn` so\n  `model_fn` is on TPU by default. To place ops on CPU,\n  `tpu_replication.outside_compilation(host_call, logits)` can be used.\n\n  Example:\n  ----------------\n\n  ```\n  def model_fn(features, labels, mode, config, params):\n    ...\n    logits = ...\n    export_outputs = {\n      'logits': export_output_lib.PredictOutput(\n        {'logits': logits})\n    }\n\n    def host_call(logits):\n      class_ids = math_ops.argmax(logits)\n      classes = string_ops.as_string(class_ids)\n      export_outputs['classes'] =\n        export_output_lib.ClassificationOutput(classes=classes)\n\n    tpu_replication.outside_compilation(host_call, logits)\n\n    ...\n  ```\n\n  In V2, `export_saved_model()` sets up `params['use_tpu']` flag to let the user\n  know if the code is exporting to TPU (or not). When `params['use_tpu']` is\n  `True`, users need to call `tpu.rewrite()`, `TPUPartitionedCallOp` and/or\n  `batch_function()`.\n\n  TIP: V2 is recommended as it is more flexible (eg: batching, etc).\n\n  @compatibility(TF2)\n  TPU Estimator manages its own TensorFlow graph and session, so it is not\n  compatible with TF2 behaviors. We recommend that you migrate to the newer\n  `tf.distribute.TPUStrategy`. 
See the\n  [TPU guide](https://www.tensorflow.org/guide/tpu) for details.\n  @end_compatibility\n  \"\"\"\n\n  def __init__(self,\n               model_fn=None,\n               model_dir=None,\n               config=None,\n               params=None,\n               use_tpu=True,\n               train_batch_size=None,\n               eval_batch_size=None,\n               predict_batch_size=None,\n               batch_axis=None,\n               eval_on_tpu=True,\n               export_to_tpu=True,\n               export_to_cpu=True,\n               warm_start_from=None,\n               embedding_config_spec=None,\n               export_saved_model_api_version=ExportSavedModelApiVersion.V1):\n    \"\"\"Constructs a `TPUEstimator` instance.\n\n    Args:\n      model_fn: Model function as required by `Estimator` which returns\n        EstimatorSpec or TPUEstimatorSpec. `training_hooks`, `evaluation_hooks`,\n        and `prediction_hooks` must not capture any TPU Tensor inside the\n        model_fn.\n      model_dir: Directory to save model parameters, graph, etc. This can\n        also be used to load checkpoints from the directory into an estimator to\n        continue training a previously saved model. If `None`, the model_dir in\n        `config` will be used if set. If both are set, they must be the same. If\n        both are `None`, a temporary directory will be used.\n      config: A `tpu_config.RunConfig` configuration object. Cannot be `None`.\n      params: An optional `dict` of hyper parameters that will be passed into\n        `input_fn` and `model_fn`.  Keys are names of parameters, values are\n        basic python types. There are reserved keys for `TPUEstimator`,\n        including 'batch_size'.\n      use_tpu: A bool indicating whether TPU support is enabled. Currently,\n        TPU training and evaluation respect this bit, but eval_on_tpu can\n        override execution of eval. 
See below.\n      train_batch_size: An int representing the global training batch size.\n        TPUEstimator transforms this global batch size to a per-shard batch\n        size, as params['batch_size'], when calling `input_fn` and `model_fn`.\n        Cannot be `None` if `use_tpu` is `True`. Must be divisible by total\n        number of replicas.\n      eval_batch_size: An int representing evaluation batch size. Must be\n        divisible by total number of replicas.\n      predict_batch_size: An int representing the prediction batch size. Must be\n        divisible by total number of replicas.\n      batch_axis: A python tuple of int values describing how each tensor\n        produced by the Estimator `input_fn` should be split across the TPU\n        compute shards. For example, if your input_fn produced (images, labels)\n        where the images tensor is in `HWCN` format, your shard dimensions would\n        be [3, 0], where 3 corresponds to the `N` dimension of your images\n        Tensor, and 0 corresponds to the dimension along which to split the\n        labels to match up with the corresponding images. If None is supplied,\n        and per_host_input_for_training is True, batches will be sharded based\n        on the major dimension. If tpu_config.per_host_input_for_training is\n        False or `PER_HOST_V2`, batch_axis is ignored.\n      eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the\n        model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`.\n      export_to_tpu: If True, `export_saved_model()` exports a metagraph for\n        serving on TPU. Note that unsupported export modes such as EVAL will be\n        ignored. For those modes, only a CPU model will be exported. 
Currently,\n        export_to_tpu only supports PREDICT.\n      export_to_cpu: If True, `export_saved_model()` exports a metagraph for\n        serving on CPU.\n      warm_start_from: Optional string filepath to a checkpoint or SavedModel to\n        warm-start from, or a `tf.estimator.WarmStartSettings` object to fully\n        configure warm-starting.  If the string filepath is provided instead of\n        a `WarmStartSettings`, then all variables are warm-started, and it is\n        assumed that vocabularies and Tensor names are unchanged.\n      embedding_config_spec: Optional EmbeddingConfigSpec instance to support\n        using TPU embedding.\n      export_saved_model_api_version: an integer: 1 or 2. 1 corresponds to V1,\n        2 corresponds to V2. (Defaults to V1). With\n        V1, `export_saved_model()` adds rewrite() and TPUPartitionedCallOp() for\n        user; while in v2, user is expected to add rewrite(),\n        TPUPartitionedCallOp() etc in their model_fn.\n\n    Raises:\n      ValueError: `params` has reserved keys already.\n    \"\"\"\n    if config is None or not isinstance(config, tpu_config.RunConfig):\n      raise ValueError(\n          '`config` must be provided with type `tpu_config.RunConfig`')\n\n    if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS):\n      raise ValueError('{} are reserved keys but existed in params {}.'.format(\n          _RESERVED_PARAMS_KEYS, params))\n\n    if use_tpu:\n      # Perform some very basic validations. 
More validations will be found in\n      # _InternalTPUContext.\n      if train_batch_size is None:\n        raise ValueError('`train_batch_size` cannot be `None`')\n      util_lib.check_positive_integer(train_batch_size, 'train_batch_size')\n\n      if (config.tpu_config.per_host_input_for_training is\n          tpu_config.InputPipelineConfig.PER_SHARD_V1 and\n          config.tpu_config.num_cores_per_replica):\n        raise ValueError(\n            'Model parallelism only supports per host input for training. '\n            'Please adjust TPURunconfig.per_host_input_for_training.')\n\n      if eval_batch_size is not None:\n        util_lib.check_positive_integer(eval_batch_size, 'eval_batch_size')\n\n      if predict_batch_size is not None:\n        util_lib.check_positive_integer(predict_batch_size,\n                                        'predict_batch_size')\n\n      if embedding_config_spec:\n        if (config.tpu_config.per_host_input_for_training not in (\n            tpu_config.InputPipelineConfig.PER_HOST_V1,\n            tpu_config.InputPipelineConfig.PER_HOST_V2)):\n          raise ValueError('Only PER_HOST_V1 and PER_HOST_V2 is supported when '\n                           'using TPU Embedding; got {}.'.format(\n                               config.tpu_config.per_host_input_for_training))\n        self._embedding_from_feature_columns = (\n            embedding_config_spec.feature_columns is not None)\n\n    if (not (use_tpu and eval_on_tpu) and embedding_config_spec and\n        embedding_config_spec.partition_strategy == 'mod'):\n      raise ValueError('Mod sharding of embedding tables not supported on '\n                       'CPU.')\n    _tpu_estimator_gauge.get_cell().set(True)\n    # Verifies the model_fn signature according to Estimator framework.\n    estimator_lib._verify_model_fn_args(model_fn, params)  # pylint: disable=protected-access\n    # We cannot store config and params in this constructor as parent\n    # constructor might change 
them, such as assigning a temp dir for\n    # config.model_dir.\n    model_function = self._augment_model_fn(model_fn, batch_axis)\n\n    # Overwrite log_step_count_steps to disable TensorLoggingHook and\n    # StepCounterHook from being created in Estimator. TPUEstimator already\n    # added equivalent hooks in _augment_model_fn above.\n    self._log_every_n_steps = config.log_step_count_steps\n    config = config.replace(log_step_count_steps=None)\n\n    # Passing non-None params as wrapped model_fn has it.\n    params = params or {}\n    super(TPUEstimator, self).__init__(\n        model_fn=model_function,\n        model_dir=model_dir,\n        config=config,\n        params=params,\n        warm_start_from=warm_start_from)\n    self._iterations_per_training_loop = util_lib.parse_iterations_per_loop(\n        self._config.tpu_config.iterations_per_loop)\n    # In absence of an explicit `log_every_n_secs` config, if the\n    # `iterations_per_loop` value is specified as time in seconds, enable\n    # logging every n secs based on the `iterations_per_loop` value. A trade-off\n    # avoiding API change on the current release.\n    # TODO(henrytan): add `log_every_n_secs` to RunConfig.\n    if self._iterations_per_training_loop.unit == 'seconds':\n      self._log_every_n_secs = self._iterations_per_training_loop.value\n      self._log_every_n_steps = None\n    elif self._iterations_per_training_loop.unit == 'count':\n      if self._log_every_n_steps is not None:\n        # Each session.run() lasts for iterations_per_loop. We can't log\n        # in-between a session.run(), and we can only log after the\n        # `iterations_per_loop` steps, so we can only approximate. 
If a user\n        # requests to log every N steps, we actually want to roughly log every\n        # N / `iterations_per_loop` steps to match the original intention.\n        self._log_every_n_steps = (\n            int(\n                math.ceil(\n                    float(self._log_every_n_steps) /\n                    self._iterations_per_training_loop.value)))\n      self._log_every_n_secs = None\n    else:\n      assert False, ('Invalid TPUConfig `iterations_per_loop` value. '\n                     'Indicates a bug in `iterations_per_loop` '\n                     'parsing.')\n\n    # All properties passed to _InternalTPUContext are immutable.\n    # pylint: disable=protected-access\n    self._ctx = tpu_context._get_tpu_context(self._config, train_batch_size,\n                                             eval_batch_size,\n                                             predict_batch_size, use_tpu,\n                                             eval_on_tpu, embedding_config_spec)\n\n    self._export_to_cpu = export_to_cpu\n    self._export_to_tpu = export_to_tpu\n\n    if not (isinstance(export_saved_model_api_version,\n                       ExportSavedModelApiVersion)\n            or export_saved_model_api_version == 1\n            or export_saved_model_api_version == 2):\n      raise ValueError('export_saved_model_api_version should be 1 or 2; '\n                       'got {}.'.format(\n                           export_saved_model_api_version))\n    self._export_saved_model_api_version = export_saved_model_api_version\n    self._is_input_fn_invoked = None\n\n    self._rendezvous = {}\n\n  def _add_meta_graph_for_mode(self,\n                               builder,\n                               input_receiver_fn_map,\n                               checkpoint_path,\n                               save_variables=True,\n                               mode=model_fn_lib.ModeKeys.PREDICT,\n                               export_tags=None,\n                           
    check_variables=True,\n                               strip_default_attrs=True):\n    if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT:\n      tf.compat.v1.logging.warn(\n          'TPUEstimator only handles mode PREDICT for exporting '\n          'when `export_to_tpu` is `True`; Mode {} will be ignored '\n          'for TPU.'.format(mode))\n\n    if not self._export_to_cpu and not self._export_to_tpu:\n      raise ValueError('One of export_to_cpu and export_to_tpu must be true.')\n\n    if self._export_to_cpu:\n      (super(TPUEstimator, self)._add_meta_graph_for_mode(\n          builder,\n          input_receiver_fn_map,\n          checkpoint_path,\n          save_variables,\n          mode=mode,\n          export_tags=export_tags,\n          check_variables=check_variables,\n          strip_default_attrs=strip_default_attrs))\n\n    if self._export_to_tpu and mode == model_fn_lib.ModeKeys.PREDICT:\n      input_receiver_fn_map = {\n          _INFERENCE_ON_TPU_MODE: input_receiver_fn_map[mode]\n      }\n      export_tags = [tf.saved_model.SERVING, tf.saved_model.TPU]\n      mode = _INFERENCE_ON_TPU_MODE\n\n      # See b/110052256 for why `check_variables` is `False`.\n      if not self._export_to_cpu:\n        check_variables = save_variables = True\n      else:\n        check_variables = save_variables = False\n      (super(TPUEstimator, self)._add_meta_graph_for_mode(\n          builder,\n          input_receiver_fn_map,\n          checkpoint_path,\n          save_variables=save_variables,\n          mode=mode,\n          export_tags=export_tags,\n          check_variables=check_variables,\n          strip_default_attrs=strip_default_attrs))\n\n  def _call_model_fn(self, features, labels, mode, config):\n    if mode == _INFERENCE_ON_TPU_MODE:\n      context = tpu._TPUInferenceContext('tpu_inference', check_ops=False)\n      try:\n        context.Enter()\n        if (\n            (self._export_saved_model_api_version ==\n             
ExportSavedModelApiVersion.V1)\n            or self._export_saved_model_api_version == 1):\n          result = self._call_model_fn_for_inference(features, labels, mode,\n                                                     config)\n        else:\n          result = super(TPUEstimator,\n                         self)._call_model_fn(features, labels, mode, config)\n      finally:\n        context.Exit()\n      return result\n    else:\n      return super(TPUEstimator, self)._call_model_fn(features, labels, mode,\n                                                      config)\n\n  def _call_model_fn_for_inference(self, features, labels, mode, config):\n    \"\"\"Wraps `_call_model_fn` for `export_saved_model`.\"\"\"\n    if mode != _INFERENCE_ON_TPU_MODE:\n      raise ValueError('mode must be {}; '\n                       'got {}.'.format(_INFERENCE_ON_TPU_MODE, mode))\n    return model_fn_inference_on_tpu(\n        self._model_fn,\n        features,\n        labels,\n        config,\n        self._params,\n        batch_config=None)\n\n  def _create_global_step(self, graph):\n    \"\"\"Creates a global step suitable for TPUs.\n\n    Args:\n      graph: The graph in which to create the global step.\n\n    Returns:\n      A global step `Tensor`.\n\n    Raises:\n      ValueError: if the global step tensor is already defined.\n    \"\"\"\n    return _create_global_step(graph)\n\n  def _convert_train_steps_to_hooks(self, steps, max_steps):\n    with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as ctx:\n      if ctx.is_running_on_cpu():\n        return super(TPUEstimator,\n                     self)._convert_train_steps_to_hooks(steps, max_steps)\n\n    # On TPU.\n    if steps is None and max_steps is None:\n      raise ValueError(\n          'For TPU training, one of `steps` or `max_steps` must be set. 
'\n          'Cannot be both `None`.')\n\n    # Estimator.train has explicit positiveness check.\n    if steps is not None:\n      util_lib.check_positive_integer(steps, 'Train steps')\n    if max_steps is not None:\n      util_lib.check_positive_integer(max_steps, 'Train max_steps')\n\n    return [\n        _TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps)\n    ]\n\n  def _convert_eval_steps_to_hooks(self, steps):\n    with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx:\n      if ctx.is_running_on_cpu():\n        return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps)\n\n    if steps is None:\n      raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.')\n\n    util_lib.check_positive_integer(steps, 'Eval steps')\n\n    return [\n        evaluation._StopAfterNEvalsHook(  # pylint: disable=protected-access\n            num_evals=steps),\n        _SetEvalIterationsHook(steps)\n    ]\n\n  def _call_input_fn(self, input_fn, mode, input_context=None):\n    \"\"\"Calls the input function.\n\n    Args:\n      input_fn: The input function.\n      mode: ModeKeys\n      input_context: Optional instance of `tf.distribute.InputContext`.\n\n    Returns:\n      In TPU mode, returns an input_fn to be called later in model_fn.\n      Otherwise, calls the input_fn and returns either fatures or\n        (features, labels).\n\n    Raises:\n      ValueError: if input_fn takes invalid arguments or does not have `params`.\n    \"\"\"\n    input_fn_args = function_utils.fn_args(input_fn)\n    config = self.config  # a deep copy.\n    kwargs = {}\n    if 'params' in input_fn_args:\n      kwargs['params'] = self.params  # a deep copy.\n    else:\n      raise ValueError('input_fn ({}) does not include params argument, '\n                       'required by TPUEstimator to pass batch size as '\n                       'params[\"batch_size\"]'.format(input_fn))\n    if 'config' in input_fn_args:\n      kwargs['config'] = 
config\n\n    if 'mode' in input_fn_args:\n      kwargs['mode'] = mode\n\n    if 'input_context' in input_fn_args:\n      kwargs['input_context'] = input_context\n\n    # Records the fact input_fn has been invoked.\n    self._is_input_fn_invoked = True\n\n    with self._ctx.with_mode(mode) as ctx:\n      if (ctx.is_running_on_cpu() and\n          ctx.is_input_slice_broadcast_to_all_cores()):\n        raise ValueError('Invalid TPUConfig `eval_training_input_configuration`'\n                         ' value. SLICED mode only works on use_tpu = True.')\n      # Setting the batch size in params first. This helps user to have same\n      # input_fn for use_tpu=True/False.\n      batch_size_for_input_fn = ctx.batch_size_for_input_fn\n      if batch_size_for_input_fn is not None:\n        _add_item_to_params(kwargs['params'], _BATCH_SIZE_KEY,\n                            batch_size_for_input_fn)\n\n      # For export_saved_model, input_fn is never passed to Estimator. So,\n      # `is_export_mode` must be False.\n      if ctx.is_running_on_cpu(is_export_mode=False):\n        with tf.compat.v1.device('/device:CPU:0'):\n          return input_fn(**kwargs)\n\n      # For TPU computation, input_fn should be invoked in a tf.while_loop for\n      # performance. While constructing the tf.while_loop, the structure of\n      # inputs returned by the `input_fn` needs to be recorded. The structure\n      # includes whether features or labels is dict or single Tensor, dict keys,\n      # tensor shapes, and dtypes. The recorded structure is used to create the\n      # infeed dequeue ops, which must be wrapped and passed as a Fn, called\n      # inside the TPU computation, as the TPU computation is wrapped inside a\n      # tf.while_loop also. So, we either pass input_fn to model_fn or pass\n      # dequeue_fn to model_fn. 
Here, `input_fn` is passed directly as\n      # `features` in `model_fn` signature.\n      def _input_fn(ctx):\n        _add_item_to_params(kwargs['params'], _CTX_KEY, ctx)\n        return input_fn(**kwargs)\n\n      return _input_fn\n\n  def _validate_features_in_predict_input(self, result):\n    \"\"\"Skip the validation.\n\n    For TPUEstimator, we do not need to check the result type. `_InputPipeline`\n    has stronger check. Parent class's check generates confusing warning msg.\n\n    Args:\n      result: `features` returned by input_fn.\n    \"\"\"\n    pass\n\n  def train(self,\n            input_fn,\n            hooks=None,\n            steps=None,\n            max_steps=None,\n            saving_listeners=None):\n    rendezvous = error_handling.ErrorRendezvous(num_sources=3)\n    self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous\n    try:\n      return super(TPUEstimator, self).train(\n          input_fn=input_fn,\n          hooks=hooks,\n          steps=steps,\n          max_steps=max_steps,\n          saving_listeners=saving_listeners)\n    except Exception:  # pylint: disable=broad-except\n      rendezvous.record_error('training_loop', sys.exc_info())\n    finally:\n      rendezvous.record_done('training_loop')\n      rendezvous.raise_errors()\n\n  def evaluate(self,\n               input_fn,\n               steps=None,\n               hooks=None,\n               checkpoint_path=None,\n               name=None):\n    rendezvous = error_handling.ErrorRendezvous(num_sources=3)\n    self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous\n    try:\n      return super(TPUEstimator, self).evaluate(\n          input_fn,\n          steps=steps,\n          hooks=hooks,\n          checkpoint_path=checkpoint_path,\n          name=name)\n    except Exception:  # pylint: disable=broad-except\n      rendezvous.record_error('evaluation_loop', sys.exc_info())\n    finally:\n      rendezvous.record_done('evaluation_loop')\n      rendezvous.raise_errors()\n\n 
 def predict(self,\n              input_fn,\n              predict_keys=None,\n              hooks=None,\n              checkpoint_path=None,\n              yield_single_examples=True):\n    rendezvous = error_handling.ErrorRendezvous(num_sources=3)\n    self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous\n    try:\n      for result in super(TPUEstimator, self).predict(\n          input_fn=input_fn,\n          predict_keys=predict_keys,\n          hooks=hooks,\n          checkpoint_path=checkpoint_path,\n          yield_single_examples=yield_single_examples):\n        yield result\n    except Exception:  # pylint: disable=broad-except\n      rendezvous.record_error('prediction_loop', sys.exc_info())\n    finally:\n      rendezvous.record_done('prediction_loop')\n      rendezvous.raise_errors()\n\n    rendezvous.record_done('prediction_loop')\n    rendezvous.raise_errors()\n\n  def _augment_model_fn(self, model_fn, batch_axis):\n    \"\"\"Returns a new model_fn, which wraps the TPU support.\"\"\"\n\n    def _model_fn(features, labels, mode, config, params):\n      \"\"\"A Estimator `model_fn` for TPUEstimator.\"\"\"\n\n      # `input_fn` is called in `train()`, `evaluate()`, and `predict()`,\n      # but not in `export_saved_model()`.\n      if self._is_input_fn_invoked:\n        is_export_mode = False\n      else:\n        is_export_mode = True\n\n      # Clear the bit.\n      self._is_input_fn_invoked = None\n\n      if is_export_mode:\n        if mode == _INFERENCE_ON_TPU_MODE:\n          _add_item_to_params(params, _USE_TPU_KEY, True)\n          mode = model_fn_lib.ModeKeys.PREDICT\n        else:\n          _add_item_to_params(params, _USE_TPU_KEY, False)\n\n      with self._ctx.with_mode(mode) as ctx:\n        model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx)\n\n        # examples_hook is added to training_hooks for both CPU and TPU\n        # execution.\n        if (self._log_every_n_steps is not None or\n            
self._log_every_n_secs is not None):\n          examples_hook = ExamplesPerSecondHook(\n              ctx.global_batch_size,\n              # pylint:disable=g-long-ternary\n              output_dir=(self.model_dir\n                          if not config or config.save_summary_steps else None),\n              # pylint:enable=g-long-ternary\n              every_n_steps=self._log_every_n_steps,\n              every_n_secs=self._log_every_n_secs)\n\n        if ctx.is_running_on_cpu(is_export_mode=is_export_mode):\n          tf.compat.v1.logging.info('Running %s on CPU/GPU', mode)\n          estimator_spec = model_fn_wrapper.call_without_tpu(\n              features, labels, is_export_mode=is_export_mode)\n          if (self._log_every_n_steps is not None or\n              self._log_every_n_secs is not None):\n            estimator_spec = estimator_spec._replace(\n                training_hooks=estimator_spec.training_hooks + (examples_hook,))\n          return estimator_spec\n\n        assert labels is None, '`labels` passed to `model_fn` must be `None`.'\n        # TPUEstimator._call_input_fn passes `input_fn` as features to here.\n        assert callable(features), '`input_fn` is not callable.'\n        input_fn = features\n\n        tpu_init_ops = []\n        if ctx.embedding_config and mode == model_fn_lib.ModeKeys.TRAIN:\n          dummy_table_variables, dummy_table_variables_init = (\n              tpu_embedding_gradient.create_dummy_table_variables(\n                  ctx.embedding_config.tpu_embedding))\n          ctx.embedding_config.dummy_table_variables = dummy_table_variables\n          tpu_init_ops.append(dummy_table_variables_init)\n\n        input_holders = _InputPipeline(input_fn, batch_axis, ctx)\n        enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = (\n            input_holders.generate_infeed_enqueue_ops_and_dequeue_fn())\n\n        graph = tf.compat.v1.get_default_graph()\n        for enqueue_op in enqueue_ops:\n          
if isinstance(enqueue_op, list):\n            graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op)\n          else:\n            graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op)\n\n        if mode == model_fn_lib.ModeKeys.TRAIN:\n          compile_op, loss, host_call, scaffold_fn, training_hooks = (\n              _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))\n          has_saver_hook = training_hooks and any(\n              isinstance(hook, tf.compat.v1.train.CheckpointSaverHook)\n              for hook in training_hooks)\n          if ctx.embedding_config:\n            g = tf.compat.v1.get_default_graph()\n            table_to_config_dict = (\n                ctx.embedding_config.tpu_embedding.table_to_config_dict)\n            optimization_parameters = (\n                ctx.embedding_config.tpu_embedding.optimization_parameters)\n            if self._embedding_from_feature_columns:\n              embedding_variable_name_by_table, slot_variable_names_by_table = (\n                  _tpu_estimator_embedding.get_full_variable_names(\n                      g, table_to_config_dict, optimization_parameters))\n            else:\n              embedding_variable_name_by_table = None\n              slot_variable_names_by_table = None\n            embedding_variables_and_ops = (\n                ctx.embedding_config.tpu_embedding.create_variables_and_ops(\n                    embedding_variable_name_by_table,\n                    slot_variable_names_by_table))\n            tpu_init_ops.extend(embedding_variables_and_ops.load_ops())\n          # scaffold_fn must be called after variables for TPU embedding has\n          # been created on CPU, as user might reinitialize those from some\n          # checkpoint within scaffold_fn.\n          scaffold = _get_scaffold(scaffold_fn)\n\n          host_ops = host_call.create_tpu_hostcall()\n\n          shutdown_hooks = []\n          shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE',\n          
                               'reset_computation')\n          if shutdown_mode:\n            if shutdown_mode == 'shutdown_worker':\n              finalizer_hooks = [\n                  session_support.ShutdownLameWorkers(),\n              ]\n            elif shutdown_mode == 'shutdown_all_workers':\n              finalizer_hooks = [\n                  session_support.ShutdownAllWorkers(),\n              ]\n            elif shutdown_mode == 'reset_computation':\n              finalizer_hooks = [\n                  session_support.ResetComputation(),\n              ]\n            elif not shutdown_mode:\n              finalizer_hooks = []\n            else:\n              raise ValueError('Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE \"%s\"' %\n                               shutdown_mode)\n\n            if finalizer_hooks:\n              if has_saver_hook:\n                saver = _NotSaver(\n                    'No save on shutdown when there are user-defined '\n                    'CheckpointSaverHooks')\n              else:\n                saver = None  # Yes automatic save on shutdown.\n              shutdown_hooks.append(\n                  session_support.GracefulShutdownHook(\n                      checkpoint_prefix=self.model_dir + '/model.ckpt',\n                      on_shutdown_hooks=finalizer_hooks,\n                      saver=saver))\n\n          with tf.control_dependencies([loss]):\n            global_step = tf.identity(tf.compat.v1.train.get_global_step())\n          hooks = input_hooks + shutdown_hooks\n\n          if ctx.feed_hook is not None:\n            tf.compat.v1.logging.info(\n                'Use user implemented tpu infeed outfeed session hook class.')\n            infeed_outfeed_session_hook_class = ctx.feed_hook\n          else:\n            infeed_outfeed_session_hook_class = TPUInfeedOutfeedSessionHook\n\n          hooks.extend([\n              infeed_outfeed_session_hook_class(\n                  ctx,\n                  enqueue_ops,\n    
              host_ops,\n                  tpu_compile_op=compile_op,\n                  run_infeed_loop_on_coordinator=(\n                      run_infeed_loop_on_coordinator),\n                  rendezvous=self._rendezvous[mode],\n                  master=self._config.master,\n                  session_config=self._session_config,\n                  tpu_init_ops=tpu_init_ops,\n                  outfeed_every_n_steps=self._config.tpu_config\n                  .experimental_host_call_every_n_steps),\n              InstallSignalHandlerHook()\n          ])\n          if _check_add_preemption_hook(self._config.cluster):\n            hooks.extend(\n                [preempted_hook.CloudTPUPreemptedHook(self._config.cluster)])\n          if (self._log_every_n_steps is not None or\n              self._log_every_n_secs is not None):\n            if self._iterations_per_training_loop.unit == 'count':\n              examples_hook._set_steps_per_run(  # pylint: disable=protected-access\n                  self._iterations_per_training_loop.value)\n            hooks.append(\n                tf.compat.v1.train.LoggingTensorHook(\n                    {\n                        'loss': tf.identity(loss),\n                        'step': global_step,\n                    },\n                    every_n_iter=self._log_every_n_steps,\n                    every_n_secs=self._log_every_n_secs))\n            hooks.append(examples_hook)\n\n          if training_hooks:\n            hooks.extend(training_hooks)\n\n          chief_hooks = []\n          if (not has_saver_hook and\n              (self._config.save_checkpoints_secs or\n               self._config.save_checkpoints_steps)):\n            checkpoint_hook = tf.compat.v1.train.CheckpointSaverHook(\n                self.model_dir,\n                save_secs=self._config.save_checkpoints_secs,\n                save_steps=self._config.save_checkpoints_steps,\n                scaffold=scaffold,\n                
save_graph_def=self._config.checkpoint_save_graph_def)\n            if self._iterations_per_training_loop.unit == 'count':\n              checkpoint_hook._set_steps_per_run(  # pylint: disable=protected-access\n                  self._iterations_per_training_loop.value)\n            chief_hooks.append(checkpoint_hook)\n          else:\n            tf.compat.v1.logging.info('Bypassing TPUEstimator hook')\n\n          tf.compat.v1.summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss)\n          with tf.control_dependencies([loss]):\n            update_ops = _sync_variables_ops(ctx)\n            if ctx.embedding_config:\n              update_ops.extend(embedding_variables_and_ops.retrieve_ops())\n\n          # Validate the TPU training graph to catch basic errors\n          _validate_tpu_training_graph(ctx)\n\n          train_op = tf.group(*update_ops)\n          graph.add_to_collection(_TPU_TRAIN_OP, train_op)\n\n          return model_fn_lib.EstimatorSpec(\n              mode,\n              loss=loss,\n              training_chief_hooks=chief_hooks,\n              training_hooks=hooks,\n              train_op=train_op,\n              scaffold=scaffold)\n\n        if mode == model_fn_lib.ModeKeys.EVAL:\n          compile_op, total_loss, host_calls, scaffold_fn, eval_hooks = (\n              _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn))\n          if ctx.embedding_config:\n            g = tf.compat.v1.get_default_graph()\n            table_to_config_dict = (\n                ctx.embedding_config.tpu_embedding.table_to_config_dict)\n            if self._embedding_from_feature_columns:\n              embedding_variable_name_by_table, _ = (\n                  _tpu_estimator_embedding.get_full_variable_names(\n                      g, table_to_config_dict))\n            else:\n              embedding_variable_name_by_table = None\n            embedding_variables_and_ops = (\n                ctx.embedding_config.tpu_embedding.create_variables_and_ops(\n             
       embedding_variable_name_by_table))\n            tpu_init_ops.extend(embedding_variables_and_ops.load_ops())\n          # scaffold_fn must be called after variables for TPU embedding has\n          # been created on CPU, as user might reinitialize those from some\n          # checkpoint within scaffold_fn.\n          scaffold = _get_scaffold(scaffold_fn)\n          iterations_per_loop_var = _create_or_get_iterations_per_loop()\n          mean_loss = tf.compat.v1.div(\n              total_loss,\n              tf.cast(iterations_per_loop_var, dtype=total_loss.dtype))\n\n          with tf.control_dependencies([mean_loss]):\n            # After TPU evaluation computation is done (the mean_loss tensor),\n            # reads all variables back from TPU and updates the eval step\n            # counter properly\n            internal_ops_to_run = _sync_variables_ops(ctx)\n            internal_ops_to_run.append(\n                _increase_eval_step_op(iterations_per_loop_var))\n\n          host_call_ret = host_calls.create_tpu_hostcall()\n          eval_metric_ops = {}\n          eval_update_ops = []\n\n          eval_metrics = host_call_ret.get('eval_metrics', {})\n          if eval_metrics:\n            # Creates a dummy metric update_op for all metrics. Estimator\n            # expects all metrics in `eval_metric_ops` have update_op and calls\n            # them one by one. The real metric update_ops are invoked in a\n            # separated thread. So, here give Estimator the dummy op for all\n            # metrics.\n            with tf.control_dependencies(internal_ops_to_run):\n              dummy_update_op = tf.no_op()\n\n            for k, v in eval_metrics.items():\n              eval_metric_ops[k] = (v[0], dummy_update_op)\n              eval_update_ops.append(v[1])\n          else:\n            # If no eval metrics are passed, create an identity node for the\n            # loss and add `internal_ops_to_run` to its dependencies. 
So\n            # `internal_ops_to_run` can be executed.\n            with tf.control_dependencies(internal_ops_to_run):\n              mean_loss = tf.identity(mean_loss)\n\n          if 'host_call' not in host_call_ret:\n            host_ops = []\n          else:\n            host_ops = host_call_ret['host_call']\n          hooks = [\n              TPUInfeedOutfeedSessionHook(\n                  ctx,\n                  enqueue_ops,\n                  eval_update_ops + host_ops,\n                  tpu_compile_op=compile_op,\n                  run_infeed_loop_on_coordinator=(\n                      run_infeed_loop_on_coordinator),\n                  rendezvous=self._rendezvous[mode],\n                  master=self._config.evaluation_master,\n                  session_config=self._session_config,\n                  tpu_init_ops=tpu_init_ops)\n          ] + input_hooks\n\n          if _check_add_preemption_hook(self._config.cluster):\n            hooks.extend(\n                [preempted_hook.CloudTPUPreemptedHook(self._config.cluster)])\n\n          if eval_hooks:\n            hooks.extend(eval_hooks)\n\n          return model_fn_lib.EstimatorSpec(\n              mode,\n              loss=mean_loss,\n              evaluation_hooks=hooks,\n              eval_metric_ops=eval_metric_ops,\n              scaffold=scaffold)\n\n        # Predict\n        assert mode == model_fn_lib.ModeKeys.PREDICT\n\n        (compile_op, dummy_predict_op, host_calls, scaffold_fn,\n         prediction_hooks) = _predict_on_tpu_system(ctx, model_fn_wrapper,\n                                                    dequeue_fn)\n        scaffold = _get_scaffold(scaffold_fn)\n        with tf.control_dependencies([dummy_predict_op]):\n          internal_ops_to_run = _sync_variables_ops(ctx)\n          with tf.control_dependencies(internal_ops_to_run):\n            dummy_predict_op = tf.no_op()\n\n        # In train and evaluation, the main TPU program is passed to monitored\n        # training session 
to run. Infeed enqueue and outfeed dequeue are\n        # executed in side threads. This is not the configuration for\n        # prediction mode.\n        #\n        # For prediction, the Estimator executes the EstimatorSpec.predictions\n        # directly and yield the element (via generator) to call site. So, the\n        # outfeed based prediction must be passed to MonitoredSession directly.\n        # Other parts of the TPU execution are organized as follows.\n        #\n        # 1. All outfeed based Tensors must be grouped with predictions Tensors\n        #    to form a single invocation. This avoid the issue we might trigger\n        #    multiple outfeeds incorrectly. To achieve this, `host_call` is\n        #    placed in control_dependencies of `stopping_signals`, and\n        #    `stopping_signals` is passed into _StoppingPredictHook, which sets\n        #    the `stopping_signals` as SessionRunArgs. MonitoredSession merges\n        #    all SessionRunArgs with the fetch in session.run together.\n        #\n        # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed Enqueue)\n        #    are grouped together. 
They will be launched once and only once in\n        #    side threads and they quit naturally according to the SAME stopping\n        #    condition.\n        enqueue_ops.append(dummy_predict_op)\n\n        host_call_ret = host_calls.create_tpu_hostcall()\n        if 'host_call' not in host_call_ret:\n          host_ops = []\n        else:\n          host_ops = host_call_ret['host_call']\n\n        predictions = host_call_ret['predictions']\n        _verify_cross_hosts_transfer_size(\n            predictions,\n            message=(\n                'The estimated size for TPUEstimatorSpec.predictions is too '\n                'large.'))\n        signals = host_call_ret['signals']\n\n        with tf.control_dependencies(host_ops):\n          host_ops = []  # Empty, we do do not need it anymore.\n          scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal(\n              signals)\n          predictions = _PaddingSignals.slice_tensor_or_dict(\n              predictions, signals)\n\n        hooks = [\n            _StoppingPredictHook(scalar_stopping_signal),\n            TPUInfeedOutfeedSessionHookForPrediction(\n                ctx,\n                enqueue_ops,\n                host_ops,\n                rendezvous=self._rendezvous[mode],\n                tpu_compile_op=compile_op,\n                master=self._config.master,\n                session_config=self._session_config),\n        ] + input_hooks\n\n        if prediction_hooks:\n          hooks.extend(prediction_hooks)\n\n        return model_fn_lib.EstimatorSpec(\n            mode,\n            prediction_hooks=hooks,\n            predictions=predictions,\n            scaffold=scaffold)\n\n    return _model_fn\n\n\ndef _check_add_preemption_hook(cluster):\n  return (tpu_cluster_resolver.is_running_in_gce() and cluster and isinstance(\n      cluster, tf.distribute.cluster_resolver.TPUClusterResolver) and\n          cluster._cloud_tpu_client.api_available())\n\n\ndef 
_export_output_to_tensors(export_output):\n  \"\"\"Get a list of `Tensors` used in `export_output`.\n\n  Args:\n    export_output: an `ExportOutput` object such as `ClassificationOutput`,\n      `RegressionOutput`, or `PredictOutput`.\n\n  Returns:\n    a list of tensors used in export_output.\n\n  Raises:\n    ValueError: if `export_output` is not one of `ClassificationOutput`,\n        `RegressionOutput`, or `PredictOutput`.\n  \"\"\"\n  if isinstance(export_output, export_output_lib.ClassificationOutput):\n    return [export_output.scores, export_output.classes]\n  elif isinstance(export_output, export_output_lib.RegressionOutput):\n    return [export_output.value]\n  elif isinstance(export_output, export_output_lib.PredictOutput):\n    return list(export_output.outputs.values())\n  else:\n    raise ValueError(\n        '`export_output` must be have type `ClassificationOutput`, '\n        '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))\n\n\ndef _clone_export_output_with_tensors(export_output, tensors):\n  \"\"\"Clones `export_output` but with new `tensors`.\n\n  Args:\n    export_output: an `ExportOutput` object such as `ClassificationOutput`,\n      `RegressionOutput`, or `PredictOutput`.\n    tensors: a list of `Tensors` used to construct a new `export_output`.\n\n  Returns:\n    A dict similar to `export_output` but with `tensors`.\n\n  Raises:\n    ValueError: if `export_output` is not one of `ClassificationOutput`,\n        `RegressionOutput`, or `PredictOutput`.\n  \"\"\"\n  if isinstance(export_output, export_output_lib.ClassificationOutput):\n    if len(tensors) != 2:\n      raise ValueError('tensors must be of length 2; '\n                       'got {}.'.format(len(tensors)))\n    return export_output_lib.ClassificationOutput(*tensors)\n  elif isinstance(export_output, export_output_lib.RegressionOutput):\n    if len(tensors) != 1:\n      raise ValueError('tensors must be of length 1; '\n                       'got 
{}'.format(len(tensors)))\n    return export_output_lib.RegressionOutput(*tensors)\n  elif isinstance(export_output, export_output_lib.PredictOutput):\n    return export_output_lib.PredictOutput(\n        dict(zip(export_output.outputs.keys(), tensors)))\n  else:\n    raise ValueError(\n        '`export_output` must be have type `ClassificationOutput`, '\n        '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output))\n\n\ndef _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):\n  \"\"\"Executes `model_fn_wrapper` multiple times on all TPU shards.\"\"\"\n  iterations_per_loop_var = _create_or_get_iterations_per_loop()\n\n  (single_tpu_eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks\n  ) = model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn)\n\n  @tpu_function.on_device_training_loop\n  def multi_tpu_eval_steps_on_single_shard(replica_id):\n    # `tpu.split_compile_and_shard()` splits and passes input for each\n    # replica as an array. As so, correctly reshape the input to be a\n    # scalar.\n    replica_id = tf.reshape(replica_id, [])\n    with tpu_context._TPUEstimatorReplicaContext(replica_id):  # pylint: disable=protected-access\n      return training_loop.repeat(iterations_per_loop_var, single_tpu_eval_step,\n                                  [_ZERO_LOSS])\n\n  # Add input that represents id for each replica in sync so that\n  # _TPUEstimatorReplicaContext can be correctly entered during\n  # replicated computation.\n  replica_id_inputs = []\n  replica_id_inputs.append([tf.constant(i) for i in range(ctx.num_replicas)])\n\n  (\n      compile_op,\n      loss,\n  ) = tpu.split_compile_and_shard(\n      multi_tpu_eval_steps_on_single_shard,\n      inputs=replica_id_inputs,\n      num_shards=ctx.num_replicas,\n      outputs_from_all_shards=False,\n      device_assignment=ctx.device_assignment)\n\n  loss = loss[0]\n  return (compile_op, loss, host_calls, captured_scaffold_fn,\n          
captured_eval_hooks.get())\n\n\ndef _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):\n  \"\"\"Executes `model_fn_wrapper` multiple times on all TPU shards.\"\"\"\n  iterations_per_loop_var = _create_or_get_iterations_per_loop()\n\n  (single_tpu_train_step, host_call, captured_scaffold_fn,\n   captured_training_hooks) = (\n       model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn))\n\n  @tpu_function.on_device_training_loop\n  def multi_tpu_train_steps_on_single_shard(replica_id):\n    # `tpu.split_compile_and_shard()` splits and passes input for each\n    # replica as an array. As so, correctly reshape the input to be a\n    # scalar.\n    replica_id = tf.reshape(replica_id, [])\n    with tpu_context._TPUEstimatorReplicaContext(replica_id):  # pylint: disable=protected-access\n      outputs = training_loop.while_loop(\n          lambda i, loss: i < iterations_per_loop_var,\n          lambda i, loss: [i + 1, single_tpu_train_step(i)],\n          inputs=[0, _INITIAL_LOSS])\n      return outputs[1:]\n\n  # Add input that represents id for each replica in sync so that\n  # _TPUEstimatorReplicaContext can be correctly entered during\n  # replicated computation.\n  replica_id_inputs = []\n  replica_id_inputs.append([tf.constant(i) for i in range(ctx.num_replicas)])\n\n  (compile_op, loss) = tpu.split_compile_and_shard(\n      multi_tpu_train_steps_on_single_shard,\n      inputs=replica_id_inputs,\n      num_shards=ctx.num_replicas,\n      outputs_from_all_shards=False,\n      device_assignment=ctx.device_assignment)\n\n  loss = loss[0]\n  return (compile_op, loss, host_call, captured_scaffold_fn,\n          captured_training_hooks.get())\n\n\ndef _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn):\n  \"\"\"Executes `model_fn_wrapper` multiple times on all TPU shards.\"\"\"\n  (single_tpu_predict_step, host_calls, captured_scaffold_fn,\n   captured_predict_hooks\n  ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn)\n\n  
@tpu_function.on_device_training_loop\n  def multi_tpu_predict_steps_on_single_shard(replica_id):\n    # `tpu.split_compile_and_shard()` splits and passes input for each\n    # replica as an array. As so, correctly reshape the input to be a\n    # scalar.\n    replica_id = tf.reshape(replica_id, [])\n    with tpu_context._TPUEstimatorReplicaContext(replica_id):  # pylint: disable=protected-access\n\n      def cond(scalar_stopping_signal):\n        return tf.math.logical_not(\n            _StopSignals.should_stop(scalar_stopping_signal))\n\n      inputs = [_StopSignals.NON_STOPPING_SIGNAL]\n      outputs = training_loop.while_loop(\n          cond, single_tpu_predict_step, inputs=inputs, name=b'loop')\n      return outputs\n\n  # Add input that represents id for each replica in sync so that\n  # _TPUEstimatorReplicaContext can be correctly entered during\n  # replicated computation.\n  replica_id_inputs = []\n  replica_id_inputs.append([tf.constant(i) for i in range(ctx.num_replicas)])\n  (\n      compile_op,\n      dummy_predict_op,\n  ) = tpu.split_compile_and_shard(\n      multi_tpu_predict_steps_on_single_shard,\n      inputs=replica_id_inputs,\n      num_shards=ctx.num_replicas,\n      outputs_from_all_shards=False,\n      device_assignment=ctx.device_assignment)\n\n  dummy_predict_op = dummy_predict_op[0]\n  return (compile_op, dummy_predict_op, host_calls, captured_scaffold_fn,\n          captured_predict_hooks.get())\n\n\ndef _wrap_computation_in_while_loop(device, op_fn):\n  \"\"\"Wraps the ops generated by `op_fn` in tf.while_loop.\"\"\"\n\n  def computation(i):\n    with tf.control_dependencies(op_fn()):\n      return i + 1\n\n  iterations_per_loop_var = _create_or_get_iterations_per_loop()\n  # By setting parallel_iterations=1, the parallel execution in while_loop is\n  # basically turned off.\n  with tf.compat.v1.device(device):\n    iterations = tf.identity(iterations_per_loop_var)\n    return tf.compat.v1.while_loop(\n        lambda i: i < 
iterations,\n        computation, [tf.constant(0)],\n        parallel_iterations=1)\n\n\ndef _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):\n  \"\"\"Wraps the ops generated by `op_fn` in tf.while_loop.\"\"\"\n\n  def cond(scalar_stopping_signal):\n    return tf.math.logical_not(_StopSignals.should_stop(scalar_stopping_signal))\n\n  def computation(unused_scalar_stopping_signal):\n    return_value = op_fn()\n    execute_ops = return_value['ops']\n    signals = return_value['signals']\n    with tf.control_dependencies(execute_ops):\n      return _StopSignals.as_scalar_stopping_signal(signals)\n\n  # By setting parallel_iterations=1, the parallel execution in while_loop is\n  # basically turned off.\n  with tf.compat.v1.device(device):\n    return tf.compat.v1.while_loop(\n        cond,\n        computation, [_StopSignals.NON_STOPPING_SIGNAL],\n        parallel_iterations=1)\n\n\ndef _validate_tpu_training_graph(ctx):\n  \"\"\"Validate graph before running distributed training.\n\n  Args:\n    ctx: A `_InternalTPUContext` instance with mode.\n\n  Raises:\n    ValueError: If the graph seems invalid for running on device\n  \"\"\"\n  if control_flow_util.ENABLE_CONTROL_FLOW_V2:\n    return  # b/124241278\n\n  operations = tf.compat.v1.get_default_graph().get_operations()\n\n  # Check if there is atleast one CrossReplicaSum operation in the graph\n  # This should be introduced by using the CrossShardOptimizer wrapper\n  cross_replica_sum_ops = [\n      o for o in operations if o.type == _CROSS_REPLICA_SUM_OP\n  ]\n  if not cross_replica_sum_ops and ctx.num_replicas > 1:\n    raise ValueError(\n        'CrossShardOptimizer must be used for model training on TPUs.')\n\n\nclass _CapturedObject(object):\n  \"\"\"A placeholder to capture an object.\n\n  This is useful when we need to capture a Python object in the Tensorflow\n  control flow body function and use it outside the control flow.\n  \"\"\"\n\n  def __init__(self):\n    self._object = None\n   
 self._captured = False\n\n  def capture(self, o):\n    if self._captured:\n      raise RuntimeError(\n          'InternalError: Object can capture only once. Please file bug.')\n\n    self._captured = True\n    self._object = o\n\n  def get(self):\n    if not self._captured:\n      raise RuntimeError(\n          'InternalError: Object is not captured properly before `get`. '\n          'Please file bug.')\n    return self._object\n\n\ndef _get_scaffold(captured_scaffold_fn):\n  \"\"\"Retrieves the Scaffold from `captured_scaffold_fn`.\"\"\"\n  with _CapturingContext(message='Inside scaffold_fn'):\n    scaffold_fn = captured_scaffold_fn.get()\n    if scaffold_fn:\n      scaffold = scaffold_fn()\n      if scaffold is None:\n        raise ValueError(\n            'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')\n    else:\n      scaffold = None\n\n  if scaffold:\n    wrapped_finalize = scaffold.finalize\n\n    def _finalize():\n      with _CapturingContext('Inside Scaffold.finalize'):\n        wrapped_finalize()\n\n    scaffold.finalize = _finalize\n  return scaffold\n\n\nclass _CapturingContext(control_flow_ops.ControlFlowContext):\n  \"\"\"Tracks references to Tensors defined in TPU replication.\"\"\"\n\n  def __init__(self, message):\n    control_flow_ops.ControlFlowContext.__init__(self)\n    self._message = message\n\n  def to_control_flow_context_def(self, context_def, export_scope=None):\n    # pylint: disable=useless-super-delegation\n    # NOTE(slebedev): the method is required by `ControlFlowContext`.\n    super(_CapturingContext,\n          self).to_control_flow_context_def(context_def, export_scope)\n\n  def AddOp(self, op):  # pylint: disable=invalid-name\n    for c in op.inputs:\n      if tpu_replication._TPU_REPLICATE_ATTR in c.op.node_def.attr:  # pylint: disable=protected-access\n        raise ValueError('{}: Op {} depends on TPU computation {}, '\n                         'which is not allowed.'.format(self._message, op, c))\n\n  
def AddValue(self, value):\n    self.AddOp(value.op)\n    return value\n\n  def __enter__(self):\n    # pylint: disable=protected-access\n    self._g = tf.compat.v1.get_default_graph()\n    self._old = self._g._get_control_flow_context()\n    self._g._set_control_flow_context(self)\n    # pylint: enable=protected-access\n\n  def __exit__(self, _, __, ___):  # pylint: disable=invalid-name\n    self._g._set_control_flow_context(self._old)  # pylint: disable=protected-access\n\n\nclass _Inputs(object):\n  \"\"\"A data structure representing the input_fn returned values.\n\n  This also supports the returned value from input_fn as `Dataset`.\n  \"\"\"\n\n  def __init__(self, features=None, labels=None, dataset=None, signals=None):\n    if dataset is not None and (features is not None or labels is not None or\n                                signals is not None):\n      raise RuntimeError('Internal Error: Either (features and labels) or '\n                         'dataset should be provided, not both. 
Please file '\n                         'bug')\n\n    self._features = features\n    self._labels = labels\n    self._signals = signals\n\n    self._dataset = dataset\n    self._iterator = None\n\n  @staticmethod\n  def from_input_fn(return_values):\n    \"\"\"Returns an `_Inputs` instance according to `input_fn` return value.\"\"\"\n    if isinstance(return_values, tf.compat.v2.data.Dataset):\n      dataset = return_values\n      return _Inputs(dataset=dataset)\n\n    features, labels = _Inputs._parse_inputs(return_values)\n    return _Inputs(features, labels)\n\n  @staticmethod\n  def _parse_inputs(return_values):\n    if isinstance(return_values, tuple):\n      features, labels = return_values\n    else:\n      features, labels = return_values, None\n    return features, labels\n\n  @property\n  def is_dataset(self):\n    \"\"\"Returns True if the return value from input_fn is Dataset.\"\"\"\n    return self._dataset is not None\n\n  def dataset_initializer(self):\n    \"\"\"Returns the dataset's initializer.\n\n    The initializer must be run before calling `features_and_labels`.\n    \"\"\"\n    self._iterator = tf.compat.v1.data.make_initializable_iterator(\n        self._dataset)\n    return self._iterator.initializer\n\n  def features_and_labels(self):\n    \"\"\"Gets `features` and `labels`.\"\"\"\n    if self.is_dataset:\n      if self._iterator is None:\n        raise RuntimeError('Internal error: Must run dataset_initializer '\n                           'before calling features_and_labels(). 
Please file '\n                           'a bug!')\n      return _Inputs._parse_inputs(self._iterator.get_next())\n\n    return (self._features, self._labels)\n\n  def signals(self):\n    return self._signals\n\n  @property\n  def dataset(self):\n    return self._dataset\n\n\nclass _InputsWithStoppingSignals(_Inputs):\n  \"\"\"Inputs with `_StopSignals` inserted into the dataset.\"\"\"\n\n  def __init__(self,\n               dataset,\n               batch_size,\n               add_padding=False,\n               num_invocations_per_step=1):\n\n    assert dataset is not None\n    user_provided_dataset = dataset.map(\n        _InputsWithStoppingSignals.insert_stopping_signal(\n            stop=False, batch_size=batch_size, add_padding=add_padding))\n    if num_invocations_per_step == 1:\n      final_batch_dataset = dataset.take(1).map(\n          _InputsWithStoppingSignals.insert_stopping_signal(\n              stop=True, batch_size=batch_size, add_padding=add_padding))\n    else:\n      # We append (2 * num_invocations_per_step - 1) batches for exhausting the\n      # user_provided_dataset and stop properly.\n      # For example, if num_invocations_per_step is 2, we append 3 additional\n      # padding batches: b1, b2, b3.\n      # If user_provided_dataset contains two batches: a1, a2\n      # Step 1: [a1, a2]\n      # Step 2: [b1, b2] -> STOP\n      # If user_provided_dataset contains three batches: a1, a2, a3.\n      # The training loops:\n      # Step 1: [a1, a2]\n      # Step 2: [a3, b1]\n      # Step 3: [b2, b3] -> STOP.\n      final_batch_dataset = dataset.take(1).map(\n          _InputsWithStoppingSignals.insert_stopping_signal(\n              stop=True, batch_size=batch_size, add_padding=add_padding))\n      final_batch_dataset = final_batch_dataset.repeat(\n          2 * num_invocations_per_step - 1)\n\n      def _set_mask(data_dict):\n        signals = data_dict['signals']\n        signals['padding_mask'] = tf.compat.v1.ones_like(\n            
signals['padding_mask'])\n        data_dict['signals'] = signals\n        return data_dict\n\n      # Mask out the extra batch.\n      final_batch_dataset = final_batch_dataset.map(_set_mask)\n\n    dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2)\n\n    super(_InputsWithStoppingSignals, self).__init__(dataset=dataset)\n    self._current_inputs = None\n\n  def features_and_labels(self):\n    if self._current_inputs is not None:\n      raise RuntimeError(\n          'Internal Error: The previous inputs have not been properly '\n          'consumed. First call features_and_labels, then call signals.')\n\n    inputs_with_signals = self._iterator.get_next()\n    features = inputs_with_signals['features']\n    labels = inputs_with_signals.get('labels')\n\n    self._current_inputs = inputs_with_signals\n    return features, labels\n\n  def signals(self):\n    \"\"\"Returns the `Signals` from `_Inputs`.\"\"\"\n    if self._current_inputs is None:\n      raise RuntimeError(\n          'Internal Error: The current inputs have not been properly '\n          'generated. First call features_and_labels, then call signals.')\n    signals = self._current_inputs['signals']\n    self._current_inputs = None\n    return signals\n\n  @staticmethod\n  def insert_stopping_signal(stop, batch_size, add_padding=False):\n    \"\"\"Inserts stopping_signal into dataset via _map_fn.\n\n    Here we change the data structure in the dataset, such that the return value\n    is a dictionary now and `features`, `labels`, and `signals` are three\n    distinguished keys in that dict. 
This provides a better structure, which\n    eases the process to decompose the inputs (see `features_and_labels`).\n\n    Args:\n      stop: bool, state of current stopping signals.\n      batch_size: int, batch size.\n      add_padding: bool, whether to pad the tensor to full batch size.\n\n    Returns:\n      A map_fn passed to dataset.map API.\n    \"\"\"\n\n    def _map_fn(*args):\n      \"\"\"The map fn to insert signals.\"\"\"\n      if len(args) == 1:\n        # Unpack the single Tensor/dict argument as features. This is required\n        # for the input_fn returns no labels.\n        args = args[0]\n      features, labels = _Inputs._parse_inputs(args)\n      new_input_dict = {}\n\n      if add_padding:\n        padding_mask, features, labels = (\n            _PaddingSignals.pad_features_and_labels(features, labels,\n                                                    batch_size))\n\n        new_input_dict['features'] = features\n        if labels is not None:\n          new_input_dict['labels'] = labels\n\n      else:\n        new_input_dict['features'] = features\n        if labels is not None:\n          new_input_dict['labels'] = labels\n        padding_mask = None\n\n      new_input_dict['signals'] = _StopSignals(\n          stop=stop, batch_size=batch_size,\n          padding_mask=padding_mask).as_dict()\n\n      return new_input_dict\n\n    return _map_fn\n\n\nclass _StopSignals(object):\n  \"\"\"Signals class holding all logic to handle TPU stopping condition.\"\"\"\n\n  NON_STOPPING_SIGNAL = False\n  STOPPING_SIGNAL = True\n\n  def __init__(self, stop, batch_size, padding_mask=None):\n    self._stop = stop\n    self._batch_size = batch_size\n    self._padding_mask = padding_mask\n\n  def as_dict(self):\n    \"\"\"Returns the signals as Python dict.\"\"\"\n    shape = [self._batch_size, 1]\n    dtype = tf.dtypes.bool\n\n    if self._stop:\n      stopping = tf.ones(shape=shape, dtype=dtype)\n    else:\n      stopping = tf.zeros(shape=shape, 
dtype=dtype)\n\n    signals = {'stopping': stopping}\n    if self._padding_mask is not None:\n      signals['padding_mask'] = self._padding_mask\n    return signals\n\n  @staticmethod\n  def as_scalar_stopping_signal(signals):\n    return tf.identity(signals['stopping'][0][0])\n\n  @staticmethod\n  def should_stop(scalar_stopping_signal):\n    \"\"\"Detects whether scalar_stopping_signal indicates stopping.\"\"\"\n    if isinstance(scalar_stopping_signal, tf.Tensor):\n      # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF\n      # way to express the bool check whether scalar_stopping_signal is True.\n      return tf.math.logical_and(scalar_stopping_signal,\n                                 _StopSignals.STOPPING_SIGNAL)\n    else:\n      # For non Tensor case, it is used in SessionRunHook. So, we cannot modify\n      # the graph anymore. Here, we use pure Python.\n      return bool(scalar_stopping_signal)\n\n\nclass _PaddingSignals(object):\n  \"\"\"Signals class holding all logic to handle padding.\"\"\"\n\n  @staticmethod\n  def pad_features_and_labels(features, labels, batch_size):\n    \"\"\"Pads out the batch dimension of features and labels.\"\"\"\n    real_batch_size = tf.compat.v1.shape(\n        _PaddingSignals._find_any_tensor(features))[0]\n\n    batch_size_tensor = tf.constant(batch_size, tf.dtypes.int32)\n\n    check_greater = tf.compat.v1.debugging.assert_greater_equal(\n        batch_size_tensor,\n        real_batch_size,\n        data=(batch_size_tensor, real_batch_size),\n        message='The real batch size should not be greater than batch_size.')\n\n    with tf.control_dependencies([check_greater]):\n      missing_count = batch_size_tensor - real_batch_size\n\n    def pad_single_tensor(tensor):\n      \"\"\"Pads out the batch dimension of a tensor to the complete batch_size.\"\"\"\n      rank = len(tensor.shape)\n      assert rank > 0\n      padding = tf.stack([[0, missing_count]] + [[0, 0]] * (rank - 1))\n      
padded_shape = (batch_size,) + tuple(tensor.shape[1:])\n      padded_tensor = tf.compat.v1.pad(tensor, padding)\n      padded_tensor.set_shape(padded_shape)\n      return padded_tensor\n\n    def nest_pad(tensor_or_dict):\n      return tf.nest.map_structure(pad_single_tensor, tensor_or_dict)\n\n    features = nest_pad(features)\n    if labels is not None:\n      labels = nest_pad(labels)\n\n    padding_mask = _PaddingSignals._padding_mask(real_batch_size, missing_count,\n                                                 batch_size)\n\n    return padding_mask, features, labels\n\n  @staticmethod\n  def slice_tensor_or_dict(tensor_or_dict, signals):\n    \"\"\"Slice the real Tensors according to padding mask in signals.\"\"\"\n\n    padding_mask = signals['padding_mask']\n    batch_size = tf.compat.v1.shape(padding_mask)[0]\n\n    def verify_batch_size(tensor):\n      check_batch_size = tf.math.equal(batch_size, tensor.shape[0])\n      with tf.control_dependencies([check_batch_size]):\n        return tf.identity(tensor)\n\n    def slice_single_tensor(tensor):\n      rank = len(tensor.shape)\n      assert rank > 0\n      real_batch_size = batch_size - tf.math.reduce_sum(padding_mask)\n      return verify_batch_size(tensor)[0:real_batch_size]\n\n    # As we split the Tensors to all TPU cores and concat them back, it is\n    # important to ensure the real data is placed before padded ones, i.e.,\n    # order is preserved. 
By that, the sliced padding mask should have all 0's.\n    # If this assertion failed, # the slice logic here would not hold.\n    sliced_padding_mask = slice_single_tensor(padding_mask)\n    assert_padding_mask = tf.math.equal(\n        tf.math.reduce_sum(sliced_padding_mask), 0)\n\n    with tf.control_dependencies([assert_padding_mask]):\n      should_stop = _StopSignals.should_stop(\n          _StopSignals.as_scalar_stopping_signal(signals))\n\n    is_full_batch = tf.math.equal(tf.math.reduce_sum(padding_mask), 0)\n\n    def slice_fn(tensor):\n      # If the current batch is full batch or part of stopping signals, we do\n      # not need to slice to save performance.\n      return tf.compat.v1.cond(\n          tf.math.logical_or(should_stop, is_full_batch),\n          (lambda: verify_batch_size(tensor)),\n          (lambda: slice_single_tensor(tensor)))\n\n    return tf.nest.map_structure(slice_fn, tensor_or_dict)\n\n  @staticmethod\n  def _find_any_tensor(batch_features):\n    tensors = [\n        x for x in tf.nest.flatten(batch_features) if isinstance(x, tf.Tensor)\n    ]\n    if not tensors:\n      raise ValueError('Cannot find any Tensor in features dict.')\n    return tensors[0]\n\n  @staticmethod\n  def _padding_mask(real_batch_size, missing_count, batch_size):\n    padding_mask = tf.concat([\n        tf.zeros((real_batch_size,), dtype=tf.dtypes.int32),\n        tf.ones((missing_count,), dtype=tf.dtypes.int32)\n    ],\n                             axis=0)\n    padding_mask.set_shape((batch_size,))\n    return padding_mask\n\n\ndef _verify_cross_hosts_transfer_size(tensor_dict, message):\n  total_size = 0\n  tensor_structure = {}\n  for key, tensor in tensor_dict.items():\n    shape = tensor.shape\n    size = np.prod(shape) * tensor.dtype.size\n    tensor_structure[key] = shape\n    total_size += size\n  if total_size >= _ONE_GIGABYTE:\n    raise ValueError(\n        '{} The transfer size is larger than the protobuf limit. 
Please '\n        'consider to use Tensors with smaller shapes or reduce batch '\n        'size. Given:\\n'\n        '{}'.format(\n            message, '\\n'.join([\n                ' -- Key: {}, Shape: {}'.format(k, v)\n                for k, v in tensor_structure.items()\n            ])))\n\n\ndef _add_item_to_params(params, key, value):\n  \"\"\"Adds a new item into `params`.\"\"\"\n  if hasattr(params, 'set_hparam'):\n    # For HParams, we need to use special API.\n    if key in params:\n      params.set_hparam(key, value)\n    else:\n      params.add_hparam(key, value)\n  else:\n    # Now params is Python dict.\n    params[key] = value\n\n\ndef export_estimator_savedmodel(estimator,\n                                export_dir_base,\n                                serving_input_receiver_fn,\n                                assets_extra=None,\n                                as_text=False,\n                                checkpoint_path=None):\n  \"\"\"Export `Estimator` trained model for TPU inference.\n\n  Args:\n    estimator: `Estimator` with which model has been trained.\n    export_dir_base: A string containing a directory in which to create\n      timestamped subdirectories containing exported SavedModels.\n    serving_input_receiver_fn: A function that takes no argument and returns a\n      `ServingInputReceiver` or `TensorServingInputReceiver`.\n    assets_extra: A dict specifying how to populate the assets.extra directory\n      within the exported SavedModel, or `None` if no extra assets are needed.\n    as_text: whether to write the SavedModel proto in text format.\n    checkpoint_path: The checkpoint path to export.  
If `None` (the default),\n      the most recent checkpoint found within the model directory is chosen.\n\n  Returns:\n    The string path to the exported directory.\n  \"\"\"\n  # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use\n  # `estimator.config`.\n  config = tpu_config.RunConfig(model_dir=estimator.model_dir)\n  est = TPUEstimator(\n      estimator._model_fn,  # pylint: disable=protected-access\n      config=config,\n      params=estimator.params,\n      use_tpu=True,\n      train_batch_size=2048,  # Does not matter.\n      eval_batch_size=2048,  # Does not matter.\n  )\n  return est.export_saved_model(export_dir_base, serving_input_receiver_fn,\n                                assets_extra, as_text, checkpoint_path)\n\n\ndef model_fn_inference_on_tpu(model_fn,\n                              features,\n                              labels=None,\n                              config=None,\n                              params=None,\n                              batch_config=None):\n  \"\"\"Convenience wrapper for export_saved_model API v2 for a model_fn.\n  WARNING:THIS METHOD IS DEPRECATED AND NOT PART OF THE APIS.\n\n  Make sure to set\n  `export_saved_model_api_version=tpu_estimator.ExportSavedModelApiVersion.V2`\n  when initializing TPUEstimator (default API version is V1). This is because\n  1) `tpu.rewrite` (or `tpu.compile`) shouldn't be called in a nested way\n      (otherwise validation will throw error like\n      \"NotImplementedError: tpu_shard_context cannot be nested.\")\n  2) When using V1 API, Estimator calls `tpu.rewrite` so\n     using `model_fn_inference_on_tpu` will trigger a nested call.\n     When using V2 API, users of Estimator needs to call `tpu.rewrite` (which\n     the wrapper does).\n\n  It attempts to execute the entire model function on the TPU for prediction.\n  Note that this does not support features which are SparseTensors. 
If you have\n  SparseTensor features, consider partitioning your model function further and\n  use inference_on_tpu.\n\n  Args:\n    model_fn: the model_fn for which we want to inference on TPU.\n    features: a tensor or dict of tensors, serves as the feature inputs to the\n      model.\n    labels: a tensor or dict of tensors, serves as the labels inputs to the\n      model.\n    config: auxiliary config to the Estimator.\n    params: hparams that we want to pass to the model_fn.\n    batch_config: a named tuple to wrap the inference batching configuration\n      inputs.\n\n  Returns:\n    An EstimatorSpec containing the outputs in export_outputs and predictions.\n  \"\"\"\n  computation, capture = _build_computation_for_inference(\n      model_fn, labels, config, params)\n  tensors = call_computation(features, computation, batch_config=batch_config)\n  estimator_spec, export_outputs_dict, predictions_dict, none_indices = (\n      capture.get())\n  predictions_list = tensors[:len(predictions_dict)]\n  export_outputs_list_without_none = tensors[len(predictions_dict):]\n\n  # Reinsert `None`s which we've taken out in\n  # `_build_computation_for_inference()`.\n  export_outputs_list = []\n  while none_indices or export_outputs_list_without_none:\n    if none_indices and none_indices[0] == len(export_outputs_list):\n      export_outputs_list.append(None)\n      none_indices.pop(0)\n    else:\n      export_outputs_list.append(export_outputs_list_without_none.pop(0))\n\n  # Reconstruct `export_outputs` with updated tensors.\n  new_export_outputs_dict = tf.nest.pack_sequence_as(export_outputs_dict,\n                                                     export_outputs_list)\n  export_outputs = estimator_spec.export_outputs\n  new_export_outputs = collections.OrderedDict(\n      (k, _clone_export_output_with_tensors(export_outputs[k], v))\n      for k, v in six.iteritems(new_export_outputs_dict))\n  # Reconstruct `predictions` with updated tensors.\n  new_predictions = 
tf.nest.pack_sequence_as(predictions_dict, predictions_list)\n  if (len(new_predictions) == 1 and\n      _KEY_WHEN_PREDICTIONS_IS_A_TENSOR in new_predictions):\n    new_predictions = new_predictions[_KEY_WHEN_PREDICTIONS_IS_A_TENSOR]\n\n  return estimator_spec._replace(\n      export_outputs=new_export_outputs, predictions=new_predictions)\n\n\ndef _build_computation_for_inference(model_fn, labels, config, params):\n  \"\"\"Builds the computation with calls the model_fn for inference.\"\"\"\n  capture = _CapturedObject()\n\n  def computation(computation_input):\n    \"\"\"Computation to be passed to `TPUPartitionedCall()`.\"\"\"\n    tpu_computation, tpu_capture = _build_tpu_computation_for_inference(\n        model_fn, computation_input, labels, config, params)\n\n    tensors_on_cpu = tf.compat.v1.tpu.rewrite(tpu_computation)\n    tpu.prune_unconnected_ops_from_xla(tf.compat.v1.get_default_graph())\n\n    (estimator_spec, export_outputs_dict, export_outputs_list,\n     predictions_dict) = (\n         tpu_capture.get())\n    predictions_list = tensors_on_cpu[:len(predictions_dict)]\n    export_outputs_tpu_on_cpu_list = tensors_on_cpu[len(predictions_dict):]\n\n    # Reconstruct tensors used in export_outputs, with TPU tensors replaced\n    # with their CPU counterpart returned from `rewrite_for_inference()`.\n    # `function.Defun()` does not like `None`s in return values, so we leave\n    # `None`s out but record their positions for later reconstruction.\n    export_outputs_list_without_none = []\n    none_indices = []\n    for i, t in enumerate(export_outputs_list):\n      if t is None:\n        none_indices.append(i)\n      else:\n        export_outputs_list_without_none.append(\n            export_outputs_tpu_on_cpu_list.pop(0))\n\n    capture.capture(\n        (estimator_spec, export_outputs_dict, predictions_dict, none_indices))\n    return predictions_list + export_outputs_list_without_none\n\n  return computation, capture\n\n\ndef 
_build_tpu_computation_for_inference(model_fn, features, labels, config,\n                                         params):\n  \"\"\"Builds the TPU computation for inference on TPU.\"\"\"\n  capture = _CapturedObject()\n\n  def computation():\n    \"\"\"Compute tpu tensors used in export_outputs.\n\n    Passed to rewrite_for_inference so that model_fn will be called under\n    the rewriting contexts. Only tpu tensors are returned, but export_outputs\n    and scaffold are captured.\n\n    Returns:\n       A list of Tensors used in export_outputs and not marked for\n       outside_compilation.\n    \"\"\"\n    # We should only call model fn once and it should be inside `computation`\n    # so that building the graph will happen under `rewrite_for_inference`.\n\n    model_fn_args = function_utils.fn_args(model_fn)\n    kwargs = {}\n    # Makes deep copy with `config` and params` in case user mutates them.\n    if 'labels' in model_fn_args:\n      kwargs['labels'] = labels\n    if 'mode' in model_fn_args:\n      kwargs['mode'] = model_fn_lib.ModeKeys.PREDICT\n    if 'config' in model_fn_args:\n      kwargs['config'] = config\n    if 'params' in model_fn_args:\n      kwargs['params'] = params\n    estimator_spec = model_fn(features, **kwargs)\n\n    # We pick the TPU tensors out from `export_output` and later return them\n    # from `computation` for rewriting.\n    export_outputs_dict = collections.OrderedDict(\n        (k, _export_output_to_tensors(v))\n        for k, v in six.iteritems(estimator_spec.export_outputs))\n    export_outputs_list = tf.nest.flatten(export_outputs_dict)\n    export_outputs_tpu_list = [t for t in export_outputs_list if t is not None]\n\n    if isinstance(estimator_spec.predictions, dict):\n      predictions_dict = collections.OrderedDict(\n          (k, v) for k, v in six.iteritems(estimator_spec.predictions))\n    else:\n      predictions_dict = {\n          _KEY_WHEN_PREDICTIONS_IS_A_TENSOR: estimator_spec.predictions\n      }\n    
predictions_list = tf.nest.flatten(predictions_dict)\n\n    # We cannot return everything we want through the return values, so\n    # capture the rest here for later use.\n    capture.capture((estimator_spec, export_outputs_dict, export_outputs_list,\n                     predictions_dict))\n    return predictions_list + export_outputs_tpu_list\n\n  return computation, capture\n\n\ndef inference_on_tpu(computation,\n                     inputs_to_tpu,\n                     num_batch_threads,\n                     max_batch_size,\n                     batch_timeout_micros,\n                     allowed_batch_sizes=None,\n                     max_enqueued_batches=100):\n  \"\"\"Convenient wrapper for export_saved_model API v2 to wrap TPU computation.\n\n  WARNING: THIS METHOD IS DEPRECATED AND NOT PART OF THE APIS.\n\n  Make sure to set\n  `export_saved_model_api_version=tpu_estimator.ExportSavedModelApiVersion.V2`\n  when initializing TPUEstimator (default API version is V1). This is because\n  1) `tpu.rewrite` (or `tpu.compile`) shouldn't be called in a nested way\n      (otherwise validation will throw error like\n      \"NotImplementedError: tpu_shard_context cannot be nested.\")\n  2) When using V1 API, Estimator calls `tpu.rewrite` so\n     using `model_fn_inference_on_tpu` will trigger a nested call.\n     When using V2 API, users of Estimator needs to call `tpu.rewrite` (which\n     the wrapper does).\n\n  It puts computation on TPU, add batching around it and round robin computation\n  between TPU cores.\n\n  See tpu_estimator_test.py for an example.\n\n  Args:\n    computation: computation to be put on TPU, which takes inputs_to_tpu as\n      arguments.\n    inputs_to_tpu: a list of tensors as input to computation.\n    num_batch_threads: Number of scheduling threads for processing batches of\n      work. Determines the number of batches processed in parallel.\n    max_batch_size: Batch sizes will never be bigger than this. 
If None or 0,\n      no batching will done.\n    batch_timeout_micros: Maximum number of microseconds to wait before\n      outputting an incomplete batch.\n    allowed_batch_sizes: Optional list of allowed batch sizes. If left empty,\n      does nothing. Otherwise, supplies a list of batch sizes, causing the op to\n      pad batches up to one of those sizes. The entries must increase\n      monotonically, and the final entry must equal max_batch_size.\n    max_enqueued_batches: The maximum depth of the batch queue. Defaults to 100.\n\n  Returns:\n    The unbatched computation output Tensors.\n  \"\"\"\n\n  def _tpu_call(args):\n    \"\"\"Function to either call or feed into BatchFunction.\"\"\"\n\n    @function.Defun(capture_resource_var_by_value=False)\n    def tpu_computation():\n      \"\"\"Function to feed into the TPUPartitionedCallOp.\"\"\"\n      tensors_on_cpu = tf.compat.v1.tpu.rewrite(computation, args)\n      tpu.prune_unconnected_ops_from_xla(tf.compat.v1.get_default_graph())\n      return tensors_on_cpu\n\n    return tpu_functional.TPUPartitionedCall(\n        args=tpu_computation.captured_inputs,\n        device_ordinal=tpu_ops.tpu_ordinal_selector(),\n        Tout=[o.type for o in tpu_computation.definition.signature.output_arg],\n        f=tpu_computation)\n\n  if not max_batch_size:\n    return _tpu_call(inputs_to_tpu)\n\n  @tf.nondifferentiable_batch_function(num_batch_threads, max_batch_size,\n                                       batch_timeout_micros,\n                                       allowed_batch_sizes,\n                                       max_enqueued_batches)\n  def batched_tpu_computation(*args):\n    \"\"\"Function to feed into the BatchOp.\"\"\"\n    return _tpu_call(args)\n\n  return batched_tpu_computation(*inputs_to_tpu)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/tpu_estimator_embedding_test.py",
    "content": "# Copyright 2021 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for TPUEstimator.\"\"\"\n\nimport itertools\nimport os\nimport tempfile\nfrom absl import flags\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.tpu import feature_column_v2 as tpu_fc_v2\nfrom tensorflow.python.tpu import tpu_embedding\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v1\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.export import export_output as export_output_lib\nfrom tensorflow_estimator.python.estimator.tpu import _tpu_estimator_embedding\nfrom tensorflow_estimator.python.estimator.tpu import tpu_config\nfrom tensorflow_estimator.python.estimator.tpu import tpu_estimator\n\nflags.DEFINE_integer('test_num_shards', 8, 'number of replicas to test')\n\nFLAGS = flags.FLAGS\n\n_TRAIN = model_fn_lib.ModeKeys.TRAIN\n_EVAL = model_fn_lib.ModeKeys.EVAL\n_PREDICT = model_fn_lib.ModeKeys.PREDICT\n_PER_HOST_V1 = tpu_config.InputPipelineConfig.PER_HOST_V1\n_PER_HOST_V2 = tpu_config.InputPipelineConfig.PER_HOST_V2\n\n# Constant used for tests that uses 
categorical_column with vocabulary\n_VOCAB_EMBEDDING_DIM = 10\n_VOCAB_SIZE = 4\n_VOCAB_NUM_BUCKETS = 5\n\n\ndef dense_computation(features):\n  return tf_keras_v1.__internal__.legacy.layers.dense(\n      features['x'], 1, kernel_initializer=tf.compat.v1.zeros_initializer())\n\n\ndef create_run_config(iterations_per_loop, **kwargs):\n  return tpu_config.RunConfig(\n      master='',\n      tpu_config=tpu_config.TPUConfig(\n          iterations_per_loop=iterations_per_loop,\n          num_shards=FLAGS.test_num_shards,\n          **kwargs),\n  )\n\n\nclass TPUEstimatorFeatureColumnTestBase(tf.test.TestCase):\n\n  def setUp(self):\n    self._old_value = tpu_estimator._WRAP_INPUT_FN_INTO_WHILE_LOOP\n\n    feature_spec = {\n        'x': tf.io.SparseFeature(['ix0', 'ix1'], 'val',\n                                       tf.int64, [1, 100]),\n        'y': tf.io.SparseFeature(['ix0', 'ix1'], 'val',\n                                       tf.int64, [1, 100])}\n    self._serving_input_receiver_fn = (\n        export.build_parsing_serving_input_receiver_fn(feature_spec))\n    super().setUp()\n\n  def tearDown(self):\n    tpu_estimator._WRAP_INPUT_FN_INTO_WHILE_LOOP = self._old_value\n    super().tearDown()\n\n  def _create_estimator_with_feature_columns(self,\n                                             feature_columns,\n                                             numeric_check=False,\n                                             use_cpu=False,\n                                             input_method=_PER_HOST_V2):\n    \"\"\"Returns TPUEstimator which uses `feature_columns` in model_fn.\"\"\"\n\n    def _model_fn(features, labels, mode, params):\n      \"\"\"Creates simple TF model using feature_columns to create input layer.\"\"\"\n      del params\n      sequence_columns, non_sequence_columns = (\n          tpu_fc_v2.split_sequence_columns_v2(feature_columns))\n      if sequence_columns:\n        sequence_layer = tf_keras.experimental.SequenceFeatures(sequence_columns)\n  
      sequence_features, sequence_lengths = sequence_layer(features)\n        sequence_lengths = tf.dtypes.cast(sequence_lengths, tf.float32)\n      if non_sequence_columns:\n        dense_layer = tf_keras_v1.layers.DenseFeatures(non_sequence_columns)\n        input_layer = dense_layer(features)\n      if numeric_check:\n        # Make predictions the same as input_layer. This is used in some tests\n        # where we set the labels to be the same as input_layer, which forces\n        # the loss to be zero.\n        if sequence_columns:\n          # For sequence columns, we return the sequence lengths, so that we can\n          # verify that these have been correctly calculated.\n          predictions = tf.concat(sequence_lengths, -1)\n        else:\n          predictions = tf.identity(input_layer)\n      else:\n        if sequence_columns:\n          # At this point we know that all the sequence features have the same\n          # max sequence length. To get the total number of entries, so we can\n          # reshape, we need total embedding dimension * max_sequence_length\n          sequence_entries_per_batch = (\n              sequence_features.shape[-1] *\n              sequence_columns[0].get_max_sequence_length())\n          flattened = tf.reshape(\n              sequence_features, [-1, sequence_entries_per_batch])\n          sequence_lengths = tf.compat.v1.expand_dims(sequence_lengths, -1)\n          input_layer = tf.concat(\n              [input_layer, flattened, sequence_lengths], -1)\n        predictions = tf_keras_v1.__internal__.legacy.layers.dense(\n            input_layer, 1, kernel_initializer=tf.compat.v1.zeros_initializer())\n\n      loss = None\n      train_op = None\n      eval_metrics = None\n      export_outputs = None\n      if mode == model_fn_lib.ModeKeys.TRAIN:\n        loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)\n        optimizer = tf.compat.v1.train.AdagradOptimizer(learning_rate=0.5)\n        optimizer = 
tf.compat.v1.tpu.CrossShardOptimizer(optimizer)\n        train_op = optimizer.minimize(\n            loss, global_step=tf.compat.v1.train.get_global_step())\n      elif mode == model_fn_lib.ModeKeys.EVAL:\n        loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)\n\n        def metric_fn_on_cpu(labels, predictions):\n          return {\n              'mse': tf.compat.v1.metrics.mean_absolute_error(labels, predictions),\n          }\n\n        eval_metrics = (metric_fn_on_cpu, [labels, predictions])\n\n      else:\n        export_outputs = {'prediction':\n                          export_output_lib.PredictOutput(\n                              {'prediction': predictions})}\n      return tpu_estimator.TPUEstimatorSpec(\n          mode=mode,\n          train_op=train_op,\n          loss=loss,\n          eval_metrics=eval_metrics,\n          export_outputs=export_outputs,\n          predictions=predictions)\n\n    run_config = create_run_config(\n        iterations_per_loop=2,\n        per_host_input_for_training=input_method)\n    embedding_config_spec = tpu_estimator.EmbeddingConfigSpec(\n        feature_columns=feature_columns,\n        optimization_parameters=tpu_estimator.AdagradParameters(\n            learning_rate=.01, initial_accumulator=0.1),\n    )\n    return tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=8,\n        eval_batch_size=8,\n        use_tpu=not use_cpu,\n        embedding_config_spec=embedding_config_spec,\n        export_to_tpu=True)\n\n\nclass TPUEstimatorFeatureColumnTest(TPUEstimatorFeatureColumnTestBase,\n                                    parameterized.TestCase):\n\n  def test_get_tpu_embedding_config_from_feature_columns(self):\n    feature_a = 'a'\n    feature_b = 'b'  # shared\n    feature_c = 'c'  # shared, weighted\n    feature_d = 'd'  # weighted\n    feature_e = 'e'  # sequence\n    feature_f = 'f'  # shared sequence\n    feature_g = 'g'  # shared 
sequence\n    feature_h = 'h'  # shared non-sequence\n\n    categorical_column_a = tf.feature_column.categorical_column_with_identity(\n        key=feature_a, num_buckets=3)\n    categorical_column_b = tf.feature_column.categorical_column_with_identity(\n        key=feature_b, num_buckets=6)\n    categorical_column_c = tf.feature_column.categorical_column_with_identity(\n        key=feature_c, num_buckets=6)\n    weight_feature_key_c = 'c_weight'\n    weighted_column_c = tf.feature_column.weighted_categorical_column(\n        categorical_column=categorical_column_c,\n        weight_feature_key=weight_feature_key_c)\n    categorical_column_d = tf.feature_column.categorical_column_with_identity(\n        key=feature_d, num_buckets=3)\n    weight_feature_key_d = 'd_weight'\n    weighted_column_d = tf.feature_column.weighted_categorical_column(\n        categorical_column=categorical_column_d,\n        weight_feature_key=weight_feature_key_d)\n    sequence_categorical_column_e = (\n        tf.feature_column.sequence_categorical_column_with_identity(\n            key=feature_e, num_buckets=7))\n    sequence_categorical_column_f = (\n        tf.feature_column.sequence_categorical_column_with_identity(\n            key=feature_f, num_buckets=4))\n    sequence_categorical_column_g = (\n        tf.feature_column.sequence_categorical_column_with_identity(\n            key=feature_g, num_buckets=4))\n    categorical_column_h = (\n        tf.feature_column.categorical_column_with_identity(\n            key=feature_h, num_buckets=4))\n\n    table_a = 'tbl_a'\n    table_bc = 'tbl_b_c_weighted_by_c_weight_shared_embedding'\n    table_e = 'tbl_e'\n    table_fgh = 'tbl_f_g_h_shared_embedding'\n    embedding_dimension_a = 2\n    embedding_dimension_bc = 5\n    embedding_dimension_d = 2\n    embedding_dimension_e = 3\n    embedding_dimension_fgh = 4\n    column_a = tf.compat.v1.tpu.experimental.embedding_column(\n        categorical_column_a,\n        
dimension=embedding_dimension_a,\n        combiner='mean',\n        initializer=lambda: 'my_initializer_a')\n    column_b, column_c = tf.compat.v1.tpu.experimental.shared_embedding_columns(\n        [categorical_column_b, weighted_column_c],\n        dimension=embedding_dimension_bc,\n        combiner='sqrtn',\n        initializer=lambda: 'my_initializer_b_c')\n    column_d = tf.compat.v1.tpu.experimental.embedding_column(\n        weighted_column_d,\n        dimension=embedding_dimension_d,\n        combiner='mean',\n        initializer=lambda: 'my_initializer_d')\n    sequence_column_e = tf.compat.v1.tpu.experimental.embedding_column(\n        sequence_categorical_column_e,\n        max_sequence_length=3,\n        dimension=embedding_dimension_e,\n        initializer=lambda: 'my_initializer_e')\n    sequence_column_f, sequence_column_g, column_h = (\n        tf.compat.v1.tpu.experimental.shared_embedding_columns(\n            [sequence_categorical_column_f, sequence_categorical_column_g,\n             categorical_column_h],\n            max_sequence_lengths=[2, 1, 0],\n            dimension=embedding_dimension_fgh,\n            initializer=lambda: 'my_initializer_f_g_h'))\n\n    table_to_config, feature_to_config = (\n        _tpu_estimator_embedding.get_configs_from_feature_columns(\n            [column_a, column_b, column_c, column_d, sequence_column_e,\n             sequence_column_f, sequence_column_g, column_h]))\n\n    self.assertEqual(feature_to_config[feature_a].table_id, table_a)\n    self.assertEqual(feature_to_config[feature_b].table_id, table_bc)\n    self.assertEqual(feature_to_config[feature_e].table_id, table_e)\n    self.assertEqual(feature_to_config[feature_f].table_id, table_fgh)\n    self.assertEqual(feature_to_config[feature_e].max_sequence_length, 3)\n    self.assertEqual(feature_to_config[feature_f].max_sequence_length, 2)\n    self.assertEqual(feature_to_config[feature_g].max_sequence_length, 1)\n    
self.assertEqual(feature_to_config[feature_h].max_sequence_length, 0)\n    self.assertEqual(table_to_config[table_a].vocabulary_size, 3)\n    self.assertEqual(table_to_config[table_bc].vocabulary_size, 6)\n    self.assertEqual(table_to_config[table_e].vocabulary_size, 7)\n    self.assertEqual(table_to_config[table_fgh].vocabulary_size, 4)\n    self.assertEqual(table_to_config[table_a].dimension, embedding_dimension_a)\n    self.assertEqual(table_to_config[table_bc].dimension,\n                     embedding_dimension_bc)\n    self.assertEqual(table_to_config[table_e].dimension, embedding_dimension_e)\n    self.assertEqual(table_to_config[table_fgh].dimension,\n                     embedding_dimension_fgh)\n    self.assertEqual(table_to_config[table_a].combiner, 'mean')\n    self.assertEqual(table_to_config[table_bc].combiner, 'sqrtn')\n    self.assertEqual(table_to_config[table_a].initializer(), 'my_initializer_a')\n    self.assertEqual(table_to_config[table_bc].initializer(),\n                     'my_initializer_b_c')\n    self.assertEqual(table_to_config[table_e].initializer(), 'my_initializer_e')\n    self.assertEqual(table_to_config[table_fgh].initializer(),\n                     'my_initializer_f_g_h')\n\n    self.assertEqual(feature_to_config[feature_a].weight_key, None)\n    self.assertEqual(feature_to_config[feature_b].weight_key, None)\n    self.assertEqual(feature_to_config[feature_c].weight_key,\n                     weight_feature_key_c)\n    self.assertEqual(feature_to_config[feature_d].weight_key,\n                     weight_feature_key_d)\n\n  def _create_estimator_with_config_dicts(self,\n                                          feature_to_config_dict,\n                                          table_to_config_dict,\n                                          use_cpu=False,\n                                          partition_strategy='div',\n                                          input_method=_PER_HOST_V2):\n    \"\"\"Returns TPUEstimator 
which uses `feature_columns` in model_fn.\"\"\"\n\n    def _model_fn(features, labels, mode, params):\n      \"\"\"Creates simple TF model using feature_columns to create input layer.\"\"\"\n      del params\n      input_features = []\n      for feature in features:\n        if len(features[feature].shape) == 1:\n          input_features.append(tf.compat.v1.expand_dims(features[feature], -1))\n        elif len(features[feature].shape) > 2:\n          input_features.append(\n              tf.reshape(features[feature],\n                                [features[feature].shape[0], -1]))\n        else:\n          input_features.append(features[feature])\n        input_features = [\n            tf.dtypes.cast(feature, tf.float32)\n            for feature in input_features]\n        input_layer = tf.concat(input_features, -1)\n        predictions = tf_keras_v1.__internal__.legacy.layers.dense(\n            input_layer, 1, kernel_initializer=tf.compat.v1.zeros_initializer())\n\n      loss = None\n      train_op = None\n      eval_metrics = None\n      export_outputs = None\n      if mode == model_fn_lib.ModeKeys.TRAIN:\n        loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)\n        optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.5)\n        optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)\n        train_op = optimizer.minimize(\n            loss, global_step=tf.compat.v1.train.get_global_step())\n      elif mode == model_fn_lib.ModeKeys.EVAL:\n        loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)\n\n        def metric_fn_on_cpu(labels, predictions):\n          return {\n              'mse': tf.compat.v1.metrics.mean_absolute_error(labels, predictions),\n          }\n\n        eval_metrics = (metric_fn_on_cpu, [labels, predictions])\n\n      else:\n        export_outputs = {'prediction':\n                          export_output_lib.PredictOutput(\n                              {'prediction': 
predictions})}\n      return tpu_estimator.TPUEstimatorSpec(\n          mode=mode,\n          train_op=train_op,\n          loss=loss,\n          eval_metrics=eval_metrics,\n          export_outputs=export_outputs,\n          predictions=predictions)\n\n    run_config = create_run_config(\n        iterations_per_loop=2,\n        per_host_input_for_training=input_method)\n    embedding_config_spec = tpu_estimator.EmbeddingConfigSpec(\n        table_to_config_dict=table_to_config_dict,\n        feature_to_config_dict=feature_to_config_dict,\n        optimization_parameters=tpu_estimator.AdagradParameters(\n            learning_rate=.01, initial_accumulator=0.1),\n        partition_strategy=partition_strategy,\n    )\n    return tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=8,\n        eval_batch_size=8,\n        use_tpu=not use_cpu,\n        embedding_config_spec=embedding_config_spec,\n        export_to_tpu=True)\n\n  def _get_vocab_feature_columns(self,\n                                 embedding_initializer=None,\n                                 is_vocabulary_file=True):\n    \"\"\"Return feature columns for categorical_column_with_vocabulary_x tests.\"\"\"\n    if is_vocabulary_file:\n      vocab_file = os.path.join(tempfile.mkdtemp(), 'vocab')\n      with open(vocab_file, 'w') as f:\n        f.write('\\n'.join([str(i) for i in range(_VOCAB_SIZE)]))\n      vocab_column = tf.compat.v1.feature_column.categorical_column_with_vocabulary_file(\n          key='x',\n          vocabulary_file=vocab_file,\n          vocabulary_size=_VOCAB_SIZE,\n          num_oov_buckets=_VOCAB_NUM_BUCKETS - _VOCAB_SIZE)\n    else:\n      vocab_list = [str(i) for i in range(_VOCAB_SIZE)]\n      vocab_column = tf.feature_column.categorical_column_with_vocabulary_list(\n          key='x',\n          vocabulary_list=vocab_list,\n          num_oov_buckets=_VOCAB_NUM_BUCKETS - _VOCAB_SIZE)\n\n    feature_columns = [\n        
tf.compat.v1.tpu.experimental.embedding_column(\n            categorical_column=vocab_column,\n            dimension=_VOCAB_EMBEDDING_DIM,\n            initializer=embedding_initializer),\n    ]\n    return set(feature_columns)\n\n  def _get_vocab_input_fn_and_feature_columns(self,\n                                              numeric_check=False,\n                                              is_vocabulary_file=True,\n                                              is_dataset=False):\n    \"\"\"Return input_fn and feature_columns for vocabulary column tests.\n\n    Args:\n      numeric_check: A boolean flag. When set to be True, the labels in input_fn\n          are set to be the expected input_layer value from the input features\n          and embedding initialization. This is to allow test to conveniently\n          do numerical check by comparing the labels against input_layer in\n          model_fn.\n      is_vocabulary_file: use categorical_column_with_vocabulary_file when set\n          True, else categorical_column_with_vocabulary_list.\n      is_dataset: A boolean value indicating whether the input_fn returns\n          dataset or not.\n\n    Returns:\n      A tuple consists of an input_fn and a set of feature columns.\n    \"\"\"\n\n    # Initialize embedding to\n    # 1 0 0 0 0 ..\n    # 0 2 0 0 0 ..\n    # 0 0 3 0 0 ..\n    # 0 0 0 4 0 ..\n    embedding_init = np.zeros((_VOCAB_SIZE, _VOCAB_EMBEDDING_DIM))\n    for i in range(_VOCAB_SIZE):\n      embedding_init[i, i] = i + 1\n    embedding_initializer = tf_keras_v1.initializers.constant(embedding_init)\n\n    def input_fn(params):\n      # Data index is [3, 2, 1, 0]\n      feature_data = tf.sparse.SparseTensor(\n          indices=[[i, 0] for i in range(_VOCAB_SIZE)],\n          values=[str(_VOCAB_SIZE - 1 - i) for i in range(_VOCAB_SIZE)],\n          dense_shape=[_VOCAB_SIZE, 1])\n\n      if numeric_check:\n        # Expected input_layer is\n        # 0 0 0 4 0 ..\n        # 0 0 3 0 0 ..\n        # 0 2 0 
0 0 ..\n        # 1 0 0 0 0 ..\n        labels = np.zeros((_VOCAB_SIZE, _VOCAB_EMBEDDING_DIM), dtype=np.float32)\n        for i in range(_VOCAB_SIZE):\n          labels[i, _VOCAB_SIZE - i - 1] = _VOCAB_SIZE - i\n      else:\n        labels = np.zeros((_VOCAB_SIZE, 1), dtype=np.float32)\n\n      data = tf.compat.v1.data.Dataset.from_tensor_slices(({\n          'x': feature_data,\n      }, labels))\n      data = data.repeat()\n      data = data.batch(params['batch_size'], drop_remainder=True)\n      if is_dataset:\n        return data\n      iterator = data.make_one_shot_iterator()\n      return iterator.get_next()\n\n    return input_fn, self._get_vocab_feature_columns(\n        embedding_initializer, is_vocabulary_file=is_vocabulary_file)\n\n  def test_feature_in_two_embeddings(self):\n    sparse_column = tf.feature_column.categorical_column_with_identity(\n        key='x', num_buckets=10)\n    feature_columns = [\n        tf.compat.v1.tpu.experimental.embedding_column(categorical_column=sparse_column,\n                                      dimension=2),\n        tf.compat.v1.tpu.experimental.embedding_column(categorical_column=sparse_column,\n                                      dimension=4)]\n    with self.assertRaisesRegex(\n        ValueError, 'is used with multiple embeddings and this '\n        'is not supported.'):\n      est = self._create_estimator_with_feature_columns(\n          feature_columns)\n      est.train(input_fn=(lambda params: 'Not used'), steps=1)\n\n  def _test_two_features(self, shared_embedding, sequence_column,\n                         input_method, use_cpu=False):\n    sparse_column1 = tf.feature_column.categorical_column_with_identity(\n        key='x', num_buckets=10)\n    if sequence_column:\n      sparse_column2 = tf.feature_column.sequence_categorical_column_with_identity(\n          key='y', num_buckets=10)\n    else:\n      sparse_column2 = tf.feature_column.categorical_column_with_identity(\n          key='y', 
num_buckets=10)\n\n    if shared_embedding:\n      if sequence_column:\n        feature_columns = tf.compat.v1.tpu.experimental.shared_embedding_columns(\n            [sparse_column1, sparse_column2], dimension=2,\n            max_sequence_lengths=[0, 2])\n      else:\n        feature_columns = tf.compat.v1.tpu.experimental.shared_embedding_columns(\n            [sparse_column1, sparse_column2], dimension=2)\n    else:\n      if sequence_column:\n        feature_columns = [\n            tf.compat.v1.tpu.experimental.embedding_column(categorical_column=sparse_column1,\n                                          dimension=2),\n            tf.compat.v1.tpu.experimental.embedding_column(categorical_column=sparse_column2,\n                                          dimension=4, max_sequence_length=2)]\n      else:\n        feature_columns = [\n            tf.compat.v1.tpu.experimental.embedding_column(categorical_column=sparse_column1,\n                                          dimension=2),\n            tf.compat.v1.tpu.experimental.embedding_column(categorical_column=sparse_column2,\n                                          dimension=4)]\n\n    def _input_fn(params):\n      feature1_data = tf.compat.v1.data.Dataset.from_tensor_slices(\n          tf.sparse.SparseTensor(\n              indices=[[i, j] for i in range(params['batch_size'])\n                       for j in [0, 1]],\n              values=[1] * (2 * params['batch_size']),\n              dense_shape=[params['batch_size'], 2]))\n      feature2_data = tf.compat.v1.data.Dataset.from_tensor_slices(\n          tf.sparse.SparseTensor(\n              indices=[[i, j] for i in range(params['batch_size'])\n                       for j in [0, 1]],\n              values=[2] * (2 * params['batch_size']),\n              dense_shape=[params['batch_size'], 2]))\n      labels_data = tf.compat.v1.data.Dataset.from_tensor_slices(\n          np.array([[0]] * params['batch_size'], dtype=np.int32))\n      dataset = 
tf.compat.v1.data.Dataset.zip(\n          (feature1_data, feature2_data, labels_data))\n      dataset = dataset.repeat()\n      dataset = dataset.batch(params['batch_size'], drop_remainder=True)\n\n      def _map(x, y, z):\n        return {'x': x, 'y': y}, z\n\n      dataset = dataset.map(_map)\n      return dataset\n\n    est = self._create_estimator_with_feature_columns(feature_columns,\n                                                      use_cpu=use_cpu,\n                                                      input_method=input_method)\n    est.train(input_fn=_input_fn, steps=1)\n    checkpoint_reader = tf.compat.v1.train.NewCheckpointReader(\n        tf.train.latest_checkpoint(est.config.model_dir))\n    return checkpoint_reader.get_variable_to_shape_map().keys()\n\n  @parameterized.named_parameters(\n      ('non_shared_non_sequence_v1', False, False, _PER_HOST_V1),\n      ('shared_non_sequence_v1', True, False, _PER_HOST_V1),\n      ('non_shared_sequence_v1', False, True, _PER_HOST_V1),\n      ('shared_sequence_v1', True, True, _PER_HOST_V1),\n      ('non_shared_non_sequence_v2', False, False, _PER_HOST_V2),\n      ('shared_non_sequence_v2', True, False, _PER_HOST_V2),\n      ('non_shared_sequence_v2', False, True, _PER_HOST_V2),\n      ('shared_sequence_v2', True, True, _PER_HOST_V2))\n  def test_two_features_with_config_dicts(self,\n                                          shared_embedding,\n                                          sequence_column,\n                                          input_method):\n    y_max_seq_length = 2 if sequence_column else 0\n    y_table = 't1' if shared_embedding else 't2'\n    feature_to_config_dict = {\n        'x': tpu_embedding.FeatureConfig(table_id='t1'),\n        'y': tpu_embedding.FeatureConfig(table_id=y_table,\n                                         max_sequence_length=y_max_seq_length)\n    }\n    table_to_config_dict = {\n        't1': tpu_embedding.TableConfig(vocabulary_size=10, dimension=2)\n    }\n    if 
not shared_embedding:\n      table_to_config_dict['t2'] = tpu_embedding.TableConfig(\n          vocabulary_size=10, dimension=4)\n\n    def _input_fn(params):\n      feature1_data = tf.compat.v1.data.Dataset.from_tensor_slices(\n          tf.sparse.SparseTensor(\n              indices=list(\n                  itertools.product(range(params['batch_size']), [0, 1])),\n              values=[1] * (2 * params['batch_size']),\n              dense_shape=[params['batch_size'], 2]))\n      feature2_data = tf.compat.v1.data.Dataset.from_tensor_slices(\n          tf.sparse.SparseTensor(\n              indices=list(\n                  itertools.product(range(params['batch_size']), [0, 1])),\n              values=[2] * (2 * params['batch_size']),\n              dense_shape=[params['batch_size'], 2]))\n      labels_data = tf.compat.v1.data.Dataset.from_tensor_slices(\n          np.array([[0]] * params['batch_size'], dtype=np.int32))\n      dataset = tf.compat.v1.data.Dataset.zip(\n          (feature1_data, feature2_data, labels_data))\n      dataset = dataset.repeat()\n      dataset = dataset.batch(params['batch_size'], drop_remainder=True)\n\n      def _map(x, y, z):\n        return {'x': x, 'y': y}, z\n\n      dataset = dataset.map(_map)\n      return dataset\n\n    est = self._create_estimator_with_config_dicts(feature_to_config_dict,\n                                                   table_to_config_dict,\n                                                   input_method=input_method)\n    est.train(input_fn=_input_fn, steps=1)\n\n  def test_non_tpu_embedding_column(self):\n    sparse_column = tf.feature_column.categorical_column_with_identity(\n        key='x', num_buckets=10)\n    sparse_column2 = tf.feature_column.categorical_column_with_identity(\n        key='y', num_buckets=10)\n    feature_columns = [\n        tf.compat.v1.tpu.experimental.embedding_column(\n            categorical_column=sparse_column, dimension=2),\n        
tf.feature_column.embedding_column(categorical_column=sparse_column2, dimension=4)\n    ]\n\n    with self.assertRaisesRegex(TypeError, 'Unsupported feature column'):\n      est = self._create_estimator_with_feature_columns(\n          feature_columns)\n      est.train(input_fn=(lambda params: 'Not used'), steps=1)\n\n  def test_feature_in_embedding_and_shared_embedding(self):\n    sparse_column1 = tf.feature_column.categorical_column_with_identity(\n        key='x', num_buckets=10)\n    sparse_column2 = tf.feature_column.categorical_column_with_identity(\n        key='y', num_buckets=10)\n\n    feature_columns = [\n        tf.compat.v1.tpu.experimental.embedding_column(categorical_column=sparse_column1,\n                                      dimension=2)\n    ] + tf.compat.v1.tpu.experimental.shared_embedding_columns([sparse_column1, sparse_column2],\n                                              dimension=4)\n\n    with self.assertRaisesRegex(\n        ValueError, 'is used with multiple embeddings and this '\n        'is not supported.'):\n      est = self._create_estimator_with_feature_columns(\n          feature_columns)\n      est.train(input_fn=(lambda params: 'Not used'), steps=1)\n\n  def test_sequence_column_with_no_max_length(self):\n    sparse_column = tf.feature_column.sequence_categorical_column_with_identity(\n        key='x', num_buckets=10)\n    with self.assertRaisesRegex(\n        ValueError, 'max_sequence_length must be greater than 0 '\n        'for sequence columns. 
Got max_sequence_length'\n        '=0 for sequence column x.'):\n      tf.compat.v1.tpu.experimental.embedding_column(categorical_column=sparse_column,\n                                    dimension=2)\n\n  def test_non_sequence_column_with_max_length(self):\n    sparse_column = tf.feature_column.categorical_column_with_identity(\n        key='x', num_buckets=10)\n    with self.assertRaisesRegex(\n        ValueError, 'Non zero max_seq_length=2 specified for non '\n        'sequence column x.'):\n      tf.compat.v1.tpu.experimental.embedding_column(categorical_column=sparse_column,\n                                    dimension=2,\n                                    max_sequence_length=2)\n\n  def test_sequence_column_shared_embedding_wrong_max_sequence_length(self):\n    sparse_column_x = tf.feature_column.sequence_categorical_column_with_identity(\n        key='x', num_buckets=10)\n    sparse_column_y = tf.feature_column.sequence_categorical_column_with_identity(\n        key='y', num_buckets=10)\n    with self.assertRaisesRegex(\n        ValueError, 'max_sequence_lengths and categorical_columns must be of'):\n      tf.compat.v1.tpu.experimental.shared_embedding_columns(\n          categorical_columns=[sparse_column_x, sparse_column_y], dimension=2,\n          max_sequence_lengths=[2])\n\n  def test_sequence_column_shared_embedding_non_sequence_with_max_length(self):\n    sparse_column_x = tf.feature_column.sequence_categorical_column_with_identity(\n        key='x', num_buckets=10)\n    sparse_column_y = tf.feature_column.categorical_column_with_identity(\n        key='y', num_buckets=10)\n    with self.assertRaisesRegex(ValueError,\n                                'Non zero max_seq_length=1 specified for non'):\n      tf.compat.v1.tpu.experimental.shared_embedding_columns(\n          categorical_columns=[sparse_column_x, sparse_column_y], dimension=2,\n          max_sequence_lengths=[2, 1])\n\n  def 
test_sequence_column_shared_embedding_sequence_without_max_length(self):\n    sparse_column_x = tf.feature_column.sequence_categorical_column_with_identity(\n        key='x', num_buckets=10)\n    sparse_column_y = tf.feature_column.categorical_column_with_identity(\n        key='y', num_buckets=10)\n    with self.assertRaisesRegex(ValueError,\n                                'max_sequence_length must be greater than 0'):\n      tf.compat.v1.tpu.experimental.shared_embedding_columns(\n          categorical_columns=[sparse_column_x, sparse_column_y], dimension=2)\n\n  @parameterized.named_parameters(\n      ('per_host_v1', _PER_HOST_V1),\n      ('per_host_v2', _PER_HOST_V2))\n  def test_sequence_column_length(self, input_method):\n    sequence_column = tf.feature_column.sequence_categorical_column_with_identity(\n        key='x', num_buckets=10)\n    feature_columns = [\n        tf.compat.v1.tpu.experimental.embedding_column(\n            categorical_column=sequence_column,\n            dimension=4,\n            max_sequence_length=10)\n    ]\n\n    def _input_fn(params):\n      sequence_lengths = np.random.randint(1, 10, params['batch_size'])\n      total = sum(sequence_lengths)\n      indices = []\n      for i in range(params['batch_size']):\n        for j in range(sequence_lengths[i]):\n          indices.append([i, j])\n      feature_data = tf.compat.v1.data.Dataset.from_tensor_slices(\n          tf.sparse.SparseTensor(\n              indices=indices,\n              values=[1] * total,\n              dense_shape=[params['batch_size'], 10])\n      )\n      labels_data = tf.compat.v1.data.Dataset.from_tensor_slices(\n          np.array(sequence_lengths, dtype=np.float32))\n      dataset = tf.compat.v1.data.Dataset.zip(\n          (feature_data, labels_data))\n      dataset = dataset.repeat()\n      dataset = dataset.batch(params['batch_size'], drop_remainder=True)\n      def _map(x, y):\n        return {'x': x}, y\n      dataset = dataset.map(_map)\n      return 
dataset\n\n    est = self._create_estimator_with_feature_columns(\n        feature_columns, numeric_check=True, input_method=input_method)\n    res = est.evaluate(input_fn=_input_fn, steps=1)\n    self.assertAllClose(res['loss'], 0)\n\n  def test_unknown_partition_strategy(self):\n    feature_to_config_dict = {'x': tpu_embedding.FeatureConfig(table_id='t1')}\n    table_to_config_dict = {\n        't1': tpu_embedding.TableConfig(vocabulary_size=10, dimension=2)\n    }\n    with self.assertRaisesRegex(\n        ValueError, 'Invalid partition_strategy invalid. Must be '\n        'one of \"mod\" or \"div\".'):\n      self._create_estimator_with_config_dicts(\n          feature_to_config_dict, table_to_config_dict, use_cpu=True,\n          partition_strategy='invalid')\n\n  def test_mod_partition_strategy_on_cpu(self):\n    feature_to_config_dict = {'x': tpu_embedding.FeatureConfig(table_id='t1')}\n    table_to_config_dict = {\n        't1': tpu_embedding.TableConfig(vocabulary_size=10, dimension=2)\n    }\n    with self.assertRaisesRegex(\n        ValueError, 'Mod sharding of embedding tables not '\n        'supported on CPU.'):\n      self._create_estimator_with_config_dicts(\n          feature_to_config_dict, table_to_config_dict, use_cpu=True,\n          partition_strategy='mod')\n\n  @parameterized.named_parameters(\n      ('non_shared_non_sequence_v1', False, False, _PER_HOST_V1),\n      ('shared_non_sequence_v1', True, False, _PER_HOST_V1),\n      ('non_shared_sequence_v1', False, True, _PER_HOST_V1),\n      ('shared_sequence_v1', True, True, _PER_HOST_V1),\n      ('non_shared_non_sequence_v2', False, False, _PER_HOST_V2),\n      ('shared_non_sequence_v2', True, False, _PER_HOST_V2),\n      ('non_shared_sequence_v2', False, True, _PER_HOST_V2),\n      ('shared_sequence_v2', True, True, _PER_HOST_V2))\n  def test_two_features(self, shared, sequence, input_method):\n    cpu_names = self._test_two_features(shared_embedding=shared,\n                                   
     sequence_column=sequence,\n                                        input_method=input_method,\n                                        use_cpu=True)\n    tpu_names = self._test_two_features(shared_embedding=shared,\n                                        sequence_column=sequence,\n                                        input_method=input_method,\n                                        use_cpu=False)\n    # TPU will have some extra variables but all CPU variables should be in the\n    # TPU checkpoint\n    for name in cpu_names:\n      self.assertIn(name, tpu_names)\n\n  @parameterized.named_parameters(\n      ('per_host_v1', _PER_HOST_V1),\n      ('per_host_v2', _PER_HOST_V2))\n  def test_dynamic_learning_rate(self, input_method):\n    sparse_column_a = tf.feature_column.categorical_column_with_identity(\n        key='a', num_buckets=10)\n    sparse_column_b = tf.feature_column.categorical_column_with_identity(\n        key='b', num_buckets=10)\n    sparse_column_c = tf.feature_column.categorical_column_with_identity(\n        key='c', num_buckets=10)\n    sparse_column_d = tf.feature_column.categorical_column_with_identity(\n        key='d', num_buckets=10)\n    sparse_column_e = tf.feature_column.categorical_column_with_identity(\n        key='e', num_buckets=10)\n    sparse_column_f = tf.feature_column.categorical_column_with_identity(\n        key='f', num_buckets=10)\n\n    static_lr = 1\n    def dynamic_learning_rate(global_step):\n      return tf.compat.v1.cond(\n          tf.math.equal(global_step, 0), lambda: 2, lambda: 0)\n\n    def shared_dynamic_learning_rate(global_step):\n      return tf.compat.v1.cond(\n          tf.math.equal(global_step, 0), lambda: 3, lambda: 0)\n\n    embedding_column_static = tf.compat.v1.tpu.experimental.embedding_column(\n        categorical_column=sparse_column_a,\n        dimension=2,\n        initializer=tf.compat.v1.ones_initializer())\n    embedding_column_dynamic = 
tf.compat.v1.tpu.experimental.embedding_column(\n        categorical_column=sparse_column_b,\n        dimension=2,\n        initializer=tf.compat.v1.ones_initializer(),\n        learning_rate_fn=dynamic_learning_rate)\n    shared_embedding_columns_static = tf.compat.v1.tpu.experimental.shared_embedding_columns(\n        [sparse_column_c, sparse_column_d],\n        dimension=2,\n        initializer=tf.compat.v1.ones_initializer())\n    shared_embedding_columns_dynamic = tf.compat.v1.tpu.experimental.shared_embedding_columns(\n        [sparse_column_e, sparse_column_f],\n        dimension=2,\n        initializer=tf.compat.v1.ones_initializer(),\n        learning_rate_fn=shared_dynamic_learning_rate)\n    feature_columns = ([embedding_column_static] + [embedding_column_dynamic] +\n                       shared_embedding_columns_static +\n                       shared_embedding_columns_dynamic)\n\n    def _input_fn(params):\n      feature_indices = [[0, 0], [1, 0], [1, 1], [1, 2]]\n      feature_values = [3, 0, 1, 2]\n      feature = tf.sparse.SparseTensor(\n          indices=feature_indices,\n          values=feature_values,\n          dense_shape=[2, 3])\n      feature_datas = tuple(\n          tf.compat.v1.data.Dataset.from_tensor_slices(feature) for _ in range(6))\n      labels_data = tf.compat.v1.data.Dataset.from_tensor_slices(\n          np.array([[0]] * 2, dtype=np.int32))\n      dataset = tf.compat.v1.data.Dataset.zip(feature_datas + (labels_data,))\n      dataset = dataset.repeat()\n\n      def _map(a, b, c, d, e, f, z):\n        return {'a': a, 'b': b, 'c': c, 'd': d, 'e': e, 'f': f}, z\n\n      dataset = dataset.map(_map)\n      dataset = dataset.batch(params['batch_size'], drop_remainder=True)\n      return dataset\n\n    def _model_fn(features, labels, mode, params):\n      \"\"\"Creates simple TF model using feature_columns to create input layer.\"\"\"\n      del params\n      assert mode == model_fn_lib.ModeKeys.TRAIN\n\n      dense_layer = 
tf_keras_v1.layers.DenseFeatures(feature_columns)\n      input_layer = dense_layer(features)\n      predictions = tf_keras_v1.__internal__.legacy.layers.dense(\n          input_layer, 1, kernel_initializer=tf.compat.v1.ones_initializer())\n\n      loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)\n      optimizer = tf.compat.v1.train.AdagradOptimizer(learning_rate=0.5)\n      optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)\n      train_op = optimizer.minimize(\n          loss, global_step=tf.compat.v1.train.get_global_step())\n      return tpu_estimator.TPUEstimatorSpec(\n          mode=mode,\n          train_op=train_op,\n          loss=loss)\n\n    run_config = create_run_config(\n        iterations_per_loop=1,\n        per_host_input_for_training=input_method)\n    optimization_parameters = (\n        tpu_estimator.StochasticGradientDescentParameters(\n            learning_rate=static_lr))\n    embedding_config_spec = tpu_estimator.EmbeddingConfigSpec(\n        feature_columns=feature_columns,\n        optimization_parameters=optimization_parameters)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=4,\n        embedding_config_spec=embedding_config_spec)\n    est.train(input_fn=_input_fn, steps=1)\n\n    checkpoint_reader = tf.compat.v1.train.NewCheckpointReader(\n        tf.train.latest_checkpoint(est.config.model_dir))\n    embedding_static = checkpoint_reader.get_tensor(\n        'dense_features/a_embedding/embedding_weights')\n    embedding_dynamic = checkpoint_reader.get_tensor(\n        'dense_features/b_embedding/embedding_weights')\n    shared_embedding_static = checkpoint_reader.get_tensor(\n        'c_d_shared_embedding')\n    shared_embedding_dynamic = checkpoint_reader.get_tensor(\n        'e_f_shared_embedding')\n\n    unit_update = embedding_static - 1.\n    unit_update_shared = shared_embedding_static - 1.\n    # The asserts below are only valid 
if unit updates are not all zero.\n    self.assertFalse(np.allclose(unit_update, 0.))\n    self.assertFalse((np.allclose(unit_update_shared, 0.)))\n    self.assertAllClose(embedding_dynamic - 1.,\n                        unit_update * 2)\n    self.assertAllClose(shared_embedding_dynamic - 1.,\n                        unit_update_shared * 3)\n\n    # train for another step\n    est.train(input_fn=_input_fn, steps=1)\n\n    checkpoint_reader2 = tf.compat.v1.train.NewCheckpointReader(\n        tf.train.latest_checkpoint(est.config.model_dir))\n    embedding_static2 = checkpoint_reader2.get_tensor(\n        'dense_features/a_embedding/embedding_weights')\n    embedding_dynamic2 = checkpoint_reader2.get_tensor(\n        'dense_features/b_embedding/embedding_weights')\n    shared_embedding_static2 = checkpoint_reader2.get_tensor(\n        'c_d_shared_embedding')\n    shared_embedding_dynamic2 = checkpoint_reader2.get_tensor(\n        'e_f_shared_embedding')\n\n    self.assertFalse(np.allclose(embedding_static,\n                                 embedding_static2))\n    self.assertFalse((np.allclose(shared_embedding_static,\n                                  shared_embedding_static2)))\n    self.assertAllClose(embedding_dynamic, embedding_dynamic2)\n    self.assertAllClose(shared_embedding_dynamic, shared_embedding_dynamic2)\n\n\nclass TPUEstimatorWeightedFeatureColumnTest(TPUEstimatorFeatureColumnTestBase,\n                                            parameterized.TestCase):\n\n  @parameterized.named_parameters(\n      ('per_host_v1', _PER_HOST_V1),\n      ('per_host_v2', _PER_HOST_V2))\n  def test_embedding_with_weighted_categorical_column(self, input_method):\n    num_buckets = 3\n    embedding_dim = 5\n    sparse_id_column = tf.feature_column.categorical_column_with_identity(\n        key='ids', num_buckets=num_buckets)\n    weighted_sparse_id_column = tf.feature_column.weighted_categorical_column(\n        categorical_column=sparse_id_column, 
weight_feature_key='values')\n\n    embedding_init = np.zeros((num_buckets, embedding_dim))\n    # Embedding initialized to\n    # 1 1 1 1 1\n    # 2 2 2 2 2\n    # 3 3 3 3 3\n    for i in range(num_buckets):\n      embedding_init[i, :] = [i + 1] * embedding_dim\n\n    feature_columns = [\n        tf.compat.v1.tpu.experimental.embedding_column(\n            categorical_column=weighted_sparse_id_column,\n            dimension=embedding_dim,\n            combiner='mean',\n            initializer=tf_keras_v1.initializers.constant(embedding_init))\n    ]\n\n    def _input_fn(params):\n      sample_size = 2\n      dense_shape = (sample_size, num_buckets)\n      indices = ((0, 0), (0, 2), (1, 0), (1, 1))\n      id_values = (2, 1, 1, 0)\n      weight_values = (0.5, 1.0, 0.2, 0.0)\n\n      inputs = tf.sparse.SparseTensor(\n          indices=indices, values=id_values, dense_shape=dense_shape)\n      weights = tf.sparse.SparseTensor(\n          indices=indices, values=weight_values, dense_shape=dense_shape)\n\n      # Setup labels so that the loss is zero\n      labels = np.zeros((sample_size, embedding_dim), dtype=np.float32)\n      for j in range(embedding_dim):\n        # \"mean\" is the weighted sum divided by the total weight.\n        labels[0, j] = (3 * 0.5 + 2 * 1.0) / (0.5 + 1.0)\n        labels[1, j] = (2 * 0.2 + 1 * 0.0) / (0.2 + 0.0)\n\n      data = tf.compat.v1.data.Dataset.from_tensor_slices(({\n          'ids': inputs,\n          'values': weights\n      }, labels))\n      data = data.repeat()\n      data = data.batch(params['batch_size'], drop_remainder=True)\n      return data\n\n    est = self._create_estimator_with_feature_columns(\n        feature_columns, numeric_check=True, input_method=input_method)\n    est.train(input_fn=_input_fn, steps=1)\n    res = est.evaluate(input_fn=_input_fn, steps=1)\n    self.assertAllClose(res['loss'], 0)\n\n  @parameterized.named_parameters(\n      ('per_host_v1', _PER_HOST_V1),\n      ('per_host_v2', _PER_HOST_V2))\n  
def test_shared_embedding_with_weighted_categorical_column_and_dataset(\n      self, input_method):\n    num_buckets = 3\n    embedding_dim = 5\n    sparse_id_column1 = tf.feature_column.categorical_column_with_identity(\n        key='ids1', num_buckets=num_buckets)\n    weighted_sparse_id_column = tf.feature_column.weighted_categorical_column(\n        categorical_column=sparse_id_column1, weight_feature_key='values')\n    sparse_id_column2 = tf.feature_column.categorical_column_with_identity(\n        key='ids2', num_buckets=num_buckets)\n\n    # Embedding initialized to\n    # 1 1 1 1 1\n    # 2 2 2 2 2\n    # 3 3 3 3 3\n    embedding_init = np.zeros((num_buckets, embedding_dim))\n    for i in range(num_buckets):\n      embedding_init[i, :] = [i + 1] * embedding_dim\n\n    feature_columns = tf.compat.v1.tpu.experimental.shared_embedding_columns(\n        categorical_columns=[weighted_sparse_id_column, sparse_id_column2],\n        dimension=embedding_dim,\n        combiner='sum',\n        initializer=tf_keras_v1.initializers.constant(embedding_init))\n\n    def _input_fn(params):\n      sample_size = 2\n      dense_shape = (sample_size, num_buckets)\n      id1_indices = ((0, 0), (0, 2), (1, 0), (1, 1))\n      id1_values = (2, 1, 1, 0)\n      id1_weight_values = (1, 2, 3, 0)  # test integer weights\n      id2_indices = ((0, 1), (1, 0), (1, 2))\n      id2_values = (1, 2, 0)\n\n      inputs1 = tf.sparse.SparseTensor(\n          indices=id1_indices, values=id1_values, dense_shape=dense_shape)\n      inputs1_weights = tf.sparse.SparseTensor(\n          indices=id1_indices,\n          values=id1_weight_values,\n          dense_shape=dense_shape)\n      inputs2 = tf.sparse.SparseTensor(\n          indices=id2_indices, values=id2_values, dense_shape=dense_shape)\n\n      # Setup labels so that the loss is zero\n      labels = np.zeros((2, embedding_dim * 2), dtype=np.float32)\n      for j in range(embedding_dim):\n        labels[0, j] = 3 * 1 + 2 * 2\n        labels[1, 
j] = 2 * 3\n      for j in range(embedding_dim, embedding_dim * 2):\n        labels[0, j] = 2\n        labels[1, j] = 3 + 1\n\n      data = tf.compat.v1.data.Dataset.from_tensor_slices(({\n          'ids1': inputs1,\n          'ids2': inputs2,\n          'values': inputs1_weights\n      }, labels))\n      data = data.repeat()\n      data = data.batch(params['batch_size'], drop_remainder=True)\n      return data\n\n    est = self._create_estimator_with_feature_columns(\n        feature_columns, numeric_check=True, input_method=input_method)\n    est.train(input_fn=_input_fn, steps=1)\n    res = est.evaluate(input_fn=_input_fn, steps=1)\n    self.assertAllClose(res['loss'], 0)\n\n  def test_embedding_with_with_weighted_categorical_column_with_vocab_error(\n      self):\n    vocab_list = [str(i) for i in range(_VOCAB_SIZE)]\n    vocab_column = tf.feature_column.categorical_column_with_vocabulary_list(\n        key='x',\n        vocabulary_list=vocab_list,\n        num_oov_buckets=_VOCAB_NUM_BUCKETS - _VOCAB_SIZE)\n    weighted_vocab_column = tf.feature_column.weighted_categorical_column(\n        categorical_column=vocab_column, weight_feature_key='values')\n\n    # Embedding initialized to\n    # 0 0 0 0 0 ...\n    # 1 1 1 1 1 ...\n    # 2 2 2 2 2 ...\n    # 3 3 3 3 3 ...\n    embedding_init = np.zeros((_VOCAB_SIZE, _VOCAB_EMBEDDING_DIM))\n    for i in range(_VOCAB_SIZE):\n      embedding_init[i, :] = [i] * _VOCAB_EMBEDDING_DIM\n    embedding_initializer = tf_keras_v1.initializers.constant(embedding_init)\n\n    feature_columns = [\n        tf.compat.v1.tpu.experimental.embedding_column(\n            categorical_column=weighted_vocab_column,\n            dimension=_VOCAB_EMBEDDING_DIM,\n            initializer=embedding_initializer),\n    ]\n\n    def _input_fn(params):\n      # Dense data after vocab -> integer conversion\n      # 2 _ 1 _ _\n      # 1 3 _ _ _\n      sample_size = 2\n      dense_shape = (sample_size, _VOCAB_NUM_BUCKETS)\n      indices = ((0, 0), (0, 
2), (1, 0), (1, 1))\n      id_values = [str(id_value) for id_value in (2, 1, 1, 0)]\n      weight_values = (0.5, 1.0, 0.2, 0.0)\n      inputs = tf.sparse.SparseTensor(\n          indices=indices, values=(id_values), dense_shape=dense_shape)\n\n      inputs = tf.sparse.SparseTensor(\n          indices=indices, values=id_values, dense_shape=dense_shape)\n      weights = tf.sparse.SparseTensor(\n          indices=indices, values=weight_values, dense_shape=dense_shape)\n\n      # setup labels to be the same as what input_layer so the loss is zero\n      # Expected input_layer is\n      labels = np.zeros((sample_size, _VOCAB_EMBEDDING_DIM), dtype=np.float32)\n      for j in range(_VOCAB_EMBEDDING_DIM):\n        # \"mean\" is the weighted sum divided by the total weight.\n        labels[0, j] = (2 * 0.5 + 1 * 1.0) / (0.5 + 1.0)\n        labels[1, j] = (1 * 0.2 + 3 * 0.0) / (0.2 + 0.0)\n\n      data = tf.compat.v1.data.Dataset.from_tensor_slices(({\n          'x': inputs,\n          'values': weights,\n      }, labels))\n      data = data.repeat()\n      data = data.batch(params['batch_size'], drop_remainder=True)\n      return data\n\n    est = self._create_estimator_with_feature_columns(\n        feature_columns, numeric_check=True)\n    with self.assertRaisesRegex(\n        ValueError, 'SparseTensor with string as values are not supported.'):\n      est.train(input_fn=_input_fn, steps=1)\n\n  def test_embedding_with_weighted_categorical_column_dense_weights_error(self):\n    num_buckets = 5\n    embedding_dim = 10\n    sparse_id_column = tf.feature_column.categorical_column_with_identity(\n        key='ids', num_buckets=num_buckets)\n    weighted_sparse_id_column = tf.feature_column.weighted_categorical_column(\n        categorical_column=sparse_id_column, weight_feature_key='values')\n\n    feature_columns = [\n        tf.compat.v1.tpu.experimental.embedding_column(\n            categorical_column=weighted_sparse_id_column,\n            dimension=embedding_dim)\n    
]\n\n    def _input_fn(params):\n      sample_size = 2\n      dense_shape = (sample_size, num_buckets)\n      indices = ((0, 0), (0, 2), (1, 0), (1, 1))\n      id_values = (2, 1, 1, 0)\n      weight_values = (0.5, 1.0, 0.2, 0.0)\n\n      inputs = tf.sparse.SparseTensor(\n          indices=indices, values=id_values, dense_shape=dense_shape)\n      weights = tf.sparse.to_dense(\n          tf.sparse.SparseTensor(\n              indices=indices, values=weight_values, dense_shape=dense_shape))\n\n      labels = np.zeros((sample_size, embedding_dim), dtype=np.float32)\n\n      data = tf.compat.v1.data.Dataset.from_tensor_slices(({\n          'ids': inputs,\n          'values': weights\n      }, labels))\n      data = data.repeat()\n      data = data.batch(params['batch_size'], drop_remainder=True)\n      return data\n\n    est = self._create_estimator_with_feature_columns(\n        feature_columns, numeric_check=True)\n    with self.assertRaisesRegex(ValueError,\n                                'Dense weights are not supported on TPU'):\n      est.train(input_fn=_input_fn, steps=1)\n\n  def test_embedding_with_weighted_categorical_column_share_weights_error(self):\n    num_buckets = 5\n    embedding_dim = 10\n    sparse_id_column1 = tf.feature_column.categorical_column_with_identity(\n        key='ids1', num_buckets=num_buckets)\n    weighted_sparse_id_column1 = tf.feature_column.weighted_categorical_column(\n        categorical_column=sparse_id_column1, weight_feature_key='values')\n    sparse_id_column2 = tf.feature_column.categorical_column_with_identity(\n        key='ids2', num_buckets=num_buckets)\n    weighted_sparse_id_column2 = tf.feature_column.weighted_categorical_column(\n        categorical_column=sparse_id_column2, weight_feature_key='values')\n\n    feature_columns = tf.compat.v1.tpu.experimental.shared_embedding_columns(\n        categorical_columns=[\n            weighted_sparse_id_column1, weighted_sparse_id_column2\n        ],\n        
dimension=embedding_dim)\n\n    def _input_fn(params):\n      sample_size = 2\n      dense_shape = (sample_size, num_buckets)\n      indices = ((0, 0), (0, 2), (1, 0), (1, 1))\n      id_values = (2, 1, 1, 0)\n      weight_values = (0.5, 1.0, 0.2, 0.0)\n\n      inputs1 = tf.sparse.SparseTensor(\n          indices=indices, values=id_values, dense_shape=dense_shape)\n      inputs2 = tf.sparse.SparseTensor(\n          indices=indices, values=id_values, dense_shape=dense_shape)\n      weights = tf.sparse.SparseTensor(\n          indices=indices, values=weight_values, dense_shape=dense_shape)\n\n      labels = np.zeros((sample_size, embedding_dim), dtype=np.float32)\n\n      data = tf.compat.v1.data.Dataset.from_tensor_slices(({\n          'ids1': inputs1,\n          'ids2': inputs2,\n          'values': weights\n      }, labels))\n      data = data.repeat()\n      data = data.batch(params['batch_size'], drop_remainder=True)\n      return data\n\n    est = self._create_estimator_with_feature_columns(\n        feature_columns, numeric_check=True)\n    with self.assertRaisesRegex(\n        ValueError,\n        'Please check if the weights are present in feature dict. 
Also note'\n        ' weight-sharing among weighted_categorical_column is not supported on '\n        'TPU.'):\n      est.train(input_fn=_input_fn, steps=1)\n\n  def _test_tensor_core_embedding(self,\n                                  shared_embedding,\n                                  both_embeddings,\n                                  input_method,\n                                  use_cpu=False):\n    sparse_column1 = tf.feature_column.categorical_column_with_identity(\n        key='x', num_buckets=10)\n    sparse_column2 = tf.feature_column.categorical_column_with_identity(\n        key='y', num_buckets=10)\n\n    if shared_embedding:\n      feature_columns = tf.compat.v1.tpu.experimental.shared_embedding_columns(\n          [sparse_column1, sparse_column2],\n          dimension=2,\n          embedding_lookup_device='tpu_tensor_core',\n          tensor_core_shape=[None, 2])\n    else:\n      feature_columns = [\n          tf.compat.v1.tpu.experimental.embedding_column(\n              categorical_column=sparse_column1,\n              dimension=2,\n              embedding_lookup_device='tpu_tensor_core',\n              tensor_core_shape=[None, 2]),\n      ]\n      if both_embeddings:\n        feature_columns.append(\n            tf.compat.v1.tpu.experimental.embedding_column(\n                categorical_column=sparse_column2,\n                dimension=4,\n                embedding_lookup_device='tpu_tensor_core',\n                tensor_core_shape=[None, 2]))\n      else:\n        feature_columns.append(\n            tf.compat.v1.tpu.experimental.embedding_column(\n                categorical_column=sparse_column2, dimension=4))\n\n    def _input_fn(params):\n      indices = []\n      for i in range(params['batch_size']):\n        for j in [0, 1]:\n          indices.append([i, j])\n      feature1_data = tf.compat.v1.data.Dataset.from_tensor_slices(\n          tf.sparse.SparseTensor(\n              indices=indices,\n              values=[1] * (2 * 
params['batch_size']),\n              dense_shape=[params['batch_size'], 2]))\n      feature2_data = tf.compat.v1.data.Dataset.from_tensor_slices(\n          tf.sparse.SparseTensor(\n              indices=indices,\n              values=[2] * (2 * params['batch_size']),\n              dense_shape=[params['batch_size'], 2]))\n      labels_data = tf.compat.v1.data.Dataset.from_tensor_slices(\n          np.array([[0]] * params['batch_size'], dtype=np.int32))\n      dataset = tf.compat.v1.data.Dataset.zip(\n          (feature1_data, feature2_data, labels_data))\n      dataset = dataset.repeat()\n      dataset = dataset.batch(params['batch_size'], drop_remainder=True)\n\n      def _map(x, y, z):\n        return {'x': x, 'y': y}, z\n\n      dataset = dataset.map(_map)\n      return dataset\n\n    est = self._create_estimator_with_feature_columns(\n        feature_columns, use_cpu=use_cpu, input_method=input_method)\n    est.train(input_fn=_input_fn, steps=1)\n    checkpoint_reader = tf.compat.v1.train.NewCheckpointReader(\n        tf.train.latest_checkpoint(est.config.model_dir))\n    return checkpoint_reader.get_variable_to_shape_map().keys()\n\n  @parameterized.named_parameters(\n      ('non_shared_single_v1', False, False, _PER_HOST_V1),\n      ('non_shared_both_v1', False, True, _PER_HOST_V1),\n      ('shared_v1', True, True, _PER_HOST_V1),\n      ('non_shared_single_v2', False, False, _PER_HOST_V2),\n      ('non_shared_both_v2', False, True, _PER_HOST_V2),\n      ('shared_v2', True, True, _PER_HOST_V2))\n  def test_tensor_core_embedding(self, shared, both_embeddings, input_method):\n    cpu_names = self._test_tensor_core_embedding(\n        shared_embedding=shared,\n        both_embeddings=both_embeddings,\n        input_method=input_method,\n        use_cpu=True)\n    tpu_names = self._test_tensor_core_embedding(\n        shared_embedding=shared,\n        both_embeddings=both_embeddings,\n        input_method=input_method,\n        use_cpu=False)\n    # TPU will 
have some extra variables but all CPU variables should be in the\n    # TPU checkpoint\n    for name in cpu_names:\n      self.assertIn(name, tpu_names)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/tpu_estimator_evaluation_test.py",
    "content": "# Copyright 2021 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for TPUEstimator evaluation related functionalities.\"\"\"\n\nfrom absl import flags\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.training import evaluation\n\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.export import export_output as export_output_lib\nfrom tensorflow_estimator.python.estimator.tpu import tpu_config\nfrom tensorflow_estimator.python.estimator.tpu import tpu_estimator\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v1\n# pylint: enable=g-direct-tensorflow-import\n\nflags.DEFINE_integer('test_num_shards', 8, 'number of replicas to test')\n\n\nFLAGS = flags.FLAGS\n\n_TRAIN = model_fn_lib.ModeKeys.TRAIN\n_EVAL = model_fn_lib.ModeKeys.EVAL\n_PREDICT = model_fn_lib.ModeKeys.PREDICT\n\n_PER_HOST = 'per_host_sharding'\n_PER_SHARD = 'per_shard_sharding'\n_UNSHARDED = 'unsharded'\n_INPUT_PIPELINE_WITH_QUEUE_RUNNER = (\n    'Input pipeline contains one or more QueueRunners')\n\n\ndef dense_computation(features):\n  return tf_keras_v1.__internal__.legacy.layers.dense(\n      features['x'], 1, kernel_initializer=tf.compat.v1.zeros_initializer())\n\n\ndef get_model_fn(export_tpu_tensor=True, export_cpu_tensor=False,\n          
       tpu_estimator_spec=True):\n\n  def model_fn(features, labels, mode, params):\n    del params\n    loss = None\n    train_op = None\n    predictions = dense_computation(features)\n    export_outputs = None\n    if mode != _PREDICT:\n      loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)\n      optimizer = tf.compat.v1.tpu.CrossShardOptimizer(\n          tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.5))\n      train_op = optimizer.minimize(loss, tf.compat.v1.train.get_global_step())\n    else:\n      if export_tpu_tensor:\n        key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n        export_outputs = {\n            key: export_output_lib.PredictOutput({\n                'prediction': predictions\n            })\n        }\n      else:\n        export_outputs = {}\n\n      if export_cpu_tensor:\n\n        def host_call(predictions):\n          classes = tf.as_string(predictions, name='classes')\n          classification_output = export_output_lib.ClassificationOutput(\n              classes=classes)\n          export_outputs['classification'] = classification_output\n\n        tf.compat.v1.tpu.outside_compilation(host_call, predictions)\n\n    if tpu_estimator_spec:\n      spec_type = tpu_estimator.TPUEstimatorSpec\n    else:\n      spec_type = model_fn_lib.EstimatorSpec\n\n    return spec_type(\n        mode,\n        loss=loss,\n        train_op=train_op,\n        predictions={'predictions': predictions},\n        export_outputs=export_outputs)\n\n  return model_fn\n\n\ndef dummy_input_fn_with_dataset(batch_size, repeat=True, x=None):\n  if x is None:\n    x = np.random.normal(size=[batch_size, 1]).astype(np.float32)\n  labels = [[2.0]] * batch_size\n\n  dataset1 = tf.compat.v1.data.Dataset.from_tensor_slices(x)\n  dataset2 = tf.compat.v1.data.Dataset.from_tensor_slices(labels)\n  dataset = tf.compat.v1.data.Dataset.zip((dataset1, dataset2))\n  if repeat:\n    dataset = dataset.repeat()\n  dataset = 
dataset.batch(batch_size, drop_remainder=True)\n\n  def _map(x, y):\n    return {'x': x}, y\n\n  return dataset.map(_map)\n\n\ndef dummy_input_fn(batch_size, repeat=True):\n  dataset = dummy_input_fn_with_dataset(batch_size, repeat)\n  iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)\n  return iterator.get_next()\n\n\ndef create_run_config(iterations_per_loop, **kwargs):\n  return tpu_config.RunConfig(\n      master='',\n      tpu_config=tpu_config.TPUConfig(\n          iterations_per_loop=iterations_per_loop,\n          num_shards=FLAGS.test_num_shards,\n          **kwargs),\n  )\n\n\nclass TPUEstimatorEvaluationTest(tf.test.TestCase):\n\n  def _create_input_fn(self):\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n    return _input_fn\n\n  def _create_head(self, mode, loss, eval_metrics):\n    \"\"\"Creates a head returning `TPUEstimatorSpec` based on mode.\"\"\"\n    if mode == _EVAL:\n      return tpu_estimator.TPUEstimatorSpec(\n          mode=mode, eval_metrics=eval_metrics, loss=loss)\n    # Train\n    optimizer = tf.compat.v1.tpu.CrossShardOptimizer(\n        tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.5))\n    train_op = optimizer.minimize(loss,\n                                  global_step=tf.compat.v1.train.get_global_step())\n    return tpu_estimator.TPUEstimatorSpec(\n        mode=mode, train_op=train_op, loss=loss)\n\n  def _create_head_with_eval_metric_ops(self, mode, loss, eval_metric_ops):\n    \"\"\"Creates a head returning `TPUEstimatorSpec` based on mode.\n\n    This version contains eval that will not run on TPUs, where eval_metric_ops\n    has not been split into a metrics_fn that runs on CPUs. The intent is to\n    test the entire eval (model_fn forward pass) and metrics output on CPU.\n\n    Args:\n      mode: The mode such as TRAIN, EVAL.\n      loss: Training loss `Tensor`. 
Must be either scalar, or with shape `[1]`.\n      eval_metric_ops: Dict of metric results keyed by name.\n\n    Returns:\n      An EstimatorSpec for EVAL or TPUEstimatorSpec otherwise.\n    \"\"\"\n    if mode == _EVAL:\n      return model_fn_lib.EstimatorSpec(\n          mode=mode, eval_metric_ops=eval_metric_ops, loss=loss)\n    # Train\n    optimizer = tf.compat.v1.tpu.CrossShardOptimizer(\n        tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.5))\n    train_op = optimizer.minimize(loss,\n                                  global_step=tf.compat.v1.train.get_global_step())\n    return tpu_estimator.TPUEstimatorSpec(\n        mode=mode, train_op=train_op, loss=loss)\n\n  def _metric_fn_on_cpu(self, labels, predictions):\n    return {\n        'mse': tf.compat.v1.metrics.mean_absolute_error(labels, predictions),\n    }\n\n  def _model_fn_without_eval_metrics(self, features, labels, mode, params):\n    del params  # unused.\n    predictions = tf_keras_v1.__internal__.legacy.layers.dense(\n        features['x'], 1,\n        kernel_initializer=tf.compat.v1.zeros_initializer())\n    loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)\n\n    return self._create_head(mode, loss, None)\n\n  def _model_fn_with_eval_tensor_list(self, features, labels, mode, params):\n    del params  # unused.\n    predictions = tf_keras_v1.__internal__.legacy.layers.dense(\n        features['x'], 1,\n        kernel_initializer=tf.compat.v1.zeros_initializer())\n    loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)\n\n    return self._create_head(\n        mode, loss,\n        eval_metrics=(self._metric_fn_on_cpu, [labels, predictions]))\n\n  def _model_fn_with_eval_dict(self, features, labels, mode, params):\n    del params  # unused.\n    predictions = tf_keras_v1.__internal__.legacy.layers.dense(\n        features['x'], 1,\n        kernel_initializer=tf.compat.v1.zeros_initializer())\n    loss = tf.compat.v1.losses.mean_squared_error(labels, 
predictions)\n\n    return self._create_head(\n        mode, loss,\n        eval_metrics=(self._metric_fn_on_cpu, {\n            'labels': labels,\n            'predictions': predictions}))\n\n  def _model_fn_with_eval_metric_ops(self, features, labels, mode, params):\n    del params  # unused.\n    predictions = tf_keras_v1.__internal__.legacy.layers.dense(\n        features['x'], 1,\n        kernel_initializer=tf.compat.v1.zeros_initializer())\n    loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)\n\n    eval_metric_ops = self._metric_fn_on_cpu(labels, predictions)\n    return self._create_head_with_eval_metric_ops(\n        mode, loss, eval_metric_ops)\n\n  def _test_eval_steps(self, model_fn, expected_eval_steps, iterations):\n\n    run_config = create_run_config(iterations_per_loop=iterations)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=model_fn,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16)\n\n    est.train(self._create_input_fn(), steps=1)\n\n    class _EvalStepCheckHook(tf.compat.v1.train.SessionRunHook):\n      \"\"\"Check eval step counter after one session.run.\n\n      As the evaluation sets the eval iterations as the eval steps, the\n      after_run should be invoked only once.\n      \"\"\"\n\n      def __init__(self, iterations_per_loop, test_case):\n        \"\"\"Constructs the run hook.\"\"\"\n        self._iterations = iterations_per_loop\n        self._invoked = False\n        self._test_case = test_case\n\n      def before_run(self, run_context):\n        del run_context\n        # For eval on TPU, the hook should be run only once.\n        self._test_case.assertFalse(self._invoked)\n\n      def after_run(self, run_context, run_values):\n        # To avoid race condition between the eval step read and increment in\n        # evaluation graph, we read the value explicitly here.\n        eval_steps = run_context.session.run(\n            evaluation._get_or_create_eval_step())\n 
       self._test_case.assertEqual(expected_eval_steps, eval_steps)\n        self._test_case.assertFalse(self._invoked)\n        self._invoked = True\n\n    est.evaluate(self._create_input_fn(),\n                 steps=expected_eval_steps,\n                 hooks=[_EvalStepCheckHook(iterations, self)])\n\n  def test_no_eval_metrics(self):\n    run_config = create_run_config(iterations_per_loop=2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._model_fn_without_eval_metrics,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16)\n\n    est.train(self._create_input_fn(), steps=1)\n    est.evaluate(self._create_input_fn(), steps=1)\n\n  def test_eval_steps_not_effected_by_training_iterations(self):\n    self._test_eval_steps(\n        model_fn=self._model_fn_with_eval_tensor_list,\n        expected_eval_steps=2,\n        iterations=4)\n    self._test_eval_steps(\n        model_fn=self._model_fn_with_eval_tensor_list,\n        expected_eval_steps=6,\n        iterations=4)\n\n  def test_eval_steps_with_no_eval_metrics(self):\n    self._test_eval_steps(\n        model_fn=self._model_fn_without_eval_metrics,\n        expected_eval_steps=6,\n        iterations=1)\n\n  def test_eval_metrics_with_tensor_list(self):\n    run_config = create_run_config(iterations_per_loop=2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._model_fn_with_eval_tensor_list,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16)\n\n    est.train(self._create_input_fn(), steps=1)\n    est.evaluate(self._create_input_fn(), steps=1)\n\n  def test_eval_batch_size_with_non_divisible_num_shards_broadcast_mode(self):\n    run_config = create_run_config(\n        iterations_per_loop=2,\n        per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._model_fn_with_eval_tensor_list,\n        config=run_config,\n        train_batch_size=7,\n 
       eval_batch_size=7)\n\n    est.train(self._create_input_fn(), steps=1)\n    est.evaluate(self._create_input_fn(), steps=1)\n\n  def test_eval_metrics_with_tensor_list_on_cpu(self):\n    run_config = create_run_config(iterations_per_loop=2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._model_fn_with_eval_tensor_list,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16,\n        use_tpu=False)\n\n    est.train(self._create_input_fn(), steps=1)\n    est.evaluate(self._create_input_fn(), steps=1)\n\n  def test_eval_metrics_with_dict(self):\n    run_config = create_run_config(iterations_per_loop=2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._model_fn_with_eval_dict,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16)\n\n    est.train(self._create_input_fn(), steps=1)\n    est.evaluate(self._create_input_fn(), steps=1)\n\n  def test_eval_metrics_with_dict_on_cpu(self):\n    run_config = create_run_config(iterations_per_loop=2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._model_fn_with_eval_dict,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16,\n        use_tpu=False)\n\n    est.train(self._create_input_fn(), steps=1)\n    est.evaluate(self._create_input_fn(), steps=1)\n\n  def test_eval_metrics_ops_cpu_training(self):\n    run_config = create_run_config(iterations_per_loop=2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._model_fn_with_eval_metric_ops,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16,\n        use_tpu=False,\n        eval_on_tpu=False)\n\n    est.train(self._create_input_fn(), steps=1)\n    est.evaluate(self._create_input_fn(), steps=1)\n\n  def test_eval_metrics_ops_cpu_training_warning(self):\n    run_config = create_run_config(iterations_per_loop=2)\n    est = tpu_estimator.TPUEstimator(\n        
model_fn=self._model_fn_with_eval_metric_ops,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16,\n        use_tpu=False,\n        # eval_on_tpu is ignored if use_tpu is False\n        eval_on_tpu=True)\n\n    est.train(self._create_input_fn(), steps=1)\n    est.evaluate(self._create_input_fn(), steps=1)\n\n  def test_eval_metrics_ops_tpu_training(self):\n    run_config = create_run_config(iterations_per_loop=2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._model_fn_with_eval_metric_ops,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16,\n        use_tpu=True,\n        eval_on_tpu=False)\n\n    est.train(self._create_input_fn(), steps=1)\n    est.evaluate(self._create_input_fn(), steps=1)\n\n  def test_eval_metrics_ops_tpu_training_failure(self):\n    run_config = create_run_config(iterations_per_loop=2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._model_fn_with_eval_metric_ops,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16,\n        use_tpu=True,\n        # Generates an error on eval, because model_fn(mode=EVAL)\n        # has not been split into an eval_metrics_fn.\n        eval_on_tpu=True)\n\n    est.train(self._create_input_fn(), steps=1)\n    with self.assertRaisesRegex(\n        RuntimeError, 'TPU evaluation must have type`TPUEstimatorSpec`'):\n      est.evaluate(self._create_input_fn(), steps=1)\n\n  def test_error_out_if_steps_is_float(self):\n    with self.assertRaisesRegex(TypeError, 'must be int'):\n      run_config = create_run_config(iterations_per_loop=2)\n      est = tpu_estimator.TPUEstimator(\n          model_fn=self._model_fn_with_eval_dict,\n          config=run_config,\n          train_batch_size=16,\n          eval_batch_size=16,\n          use_tpu=True)\n      est.evaluate(self._create_input_fn(), steps=12.3)\n\n  def test_error_out_if_steps_is_invalid(self):\n    with 
self.assertRaisesRegex(ValueError, 'must be positive'):\n      run_config = create_run_config(iterations_per_loop=2)\n      est = tpu_estimator.TPUEstimator(\n          model_fn=self._model_fn_with_eval_dict,\n          config=run_config,\n          train_batch_size=16,\n          eval_batch_size=16,\n          use_tpu=True)\n      est.evaluate(self._create_input_fn(), steps=-321)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/tpu_estimator_export_test.py",
    "content": "# Copyright 2021 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for TPUEstimator export related functionalities.\"\"\"\n\nfrom absl import flags\nfrom absl.testing import parameterized\nimport numpy as np\nimport os\nimport tempfile\nimport tensorflow.compat.v1 as tf\n\n# pylint: disable=g-direct-tensorflow-import\nfrom tensorflow.core.example import example_pb2\nfrom tensorflow.python import data as dataset_lib\nfrom tensorflow.python.client import session\nfrom tensorflow.python.data.ops import dataset_ops\nfrom tensorflow.python.framework import dtypes\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.ops import array_ops\nfrom tensorflow.python.ops import init_ops\nfrom tensorflow.python.ops import parsing_ops\nfrom tensorflow.python.ops import string_ops\nfrom tensorflow.python.ops.losses import losses\nfrom tensorflow.python.platform import gfile\nfrom tensorflow.python.platform import test\nfrom tensorflow.python.saved_model import loader\nfrom tensorflow.python.saved_model import loader_impl\nfrom tensorflow.python.saved_model import signature_constants\nfrom tensorflow.python.saved_model import tag_constants\nfrom tensorflow.python.training import training\nfrom tensorflow.python.util import compat\nfrom tensorflow_estimator.python.estimator import estimator as 
estimator_lib\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.export import export_lib\nfrom tensorflow_estimator.python.estimator.tpu import tpu_config\nfrom tensorflow_estimator.python.estimator.tpu import tpu_estimator\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v1\n# pylint: enable=g-direct-tensorflow-import\n\nflags.DEFINE_integer('test_num_shards', 8, 'number of replicas to test')\n\n\nFLAGS = flags.FLAGS\n\n_TRAIN = model_fn_lib.ModeKeys.TRAIN\n_EVAL = model_fn_lib.ModeKeys.EVAL\n_PREDICT = model_fn_lib.ModeKeys.PREDICT\n\n_PER_HOST = 'per_host_sharding'\n_PER_SHARD = 'per_shard_sharding'\n_UNSHARDED = 'unsharded'\n_INPUT_PIPELINE_WITH_QUEUE_RUNNER = (\n    'Input pipeline contains one or more QueueRunners')\n\n\ndef dense_computation(features):\n  return tf_keras_v1.__internal__.legacy.layers.dense(\n      features['x'], 1, kernel_initializer=init_ops.zeros_initializer())\n\n\ndef get_model_fn(export_tpu_tensor=True, export_cpu_tensor=False,\n                 tpu_estimator_spec=True):\n\n  def model_fn(features, labels, mode, params):\n    del params\n    loss = None\n    train_op = None\n    predictions = dense_computation(features)\n    export_outputs = None\n    if mode != _PREDICT:\n      loss = losses.mean_squared_error(labels, predictions)\n      optimizer = tf.tpu.CrossShardOptimizer(\n          training.GradientDescentOptimizer(learning_rate=0.5))\n      train_op = optimizer.minimize(loss, training.get_global_step())\n    else:\n      if export_tpu_tensor:\n        key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n        export_outputs = {\n            key: export_lib.PredictOutput({\n                'prediction': predictions\n            })\n        }\n      else:\n        export_outputs = {}\n\n      if export_cpu_tensor:\n\n        def host_call(predictions):\n          return string_ops.as_string(predictions, name='classes')\n\n        
classes = tf.tpu.outside_compilation(host_call, predictions)\n        classification_output = export_lib.ClassificationOutput(\n            classes=classes)\n        export_outputs['classification'] = classification_output\n\n    if tpu_estimator_spec:\n      spec_type = tpu_estimator.TPUEstimatorSpec\n    else:\n      spec_type = model_fn_lib.EstimatorSpec\n\n    return spec_type(\n        mode,\n        loss=loss,\n        train_op=train_op,\n        predictions={'predictions': predictions},\n        export_outputs=export_outputs)\n\n  return model_fn\n\n\ndef dummy_input_fn_with_dataset(batch_size, repeat=True, x=None):\n  if x is None:\n    x = np.random.normal(size=[batch_size, 1]).astype(np.float32)\n  labels = [[2.0]] * batch_size\n\n  dataset1 = dataset_lib.Dataset.from_tensor_slices(x)\n  dataset2 = dataset_lib.Dataset.from_tensor_slices(labels)\n  dataset = dataset_lib.Dataset.zip((dataset1, dataset2))\n  if repeat:\n    dataset = dataset.repeat()\n  dataset = dataset.batch(batch_size, drop_remainder=True)\n\n  def _map(x, y):\n    return {'x': x}, y\n\n  return dataset.map(_map)\n\n\ndef dummy_input_fn(batch_size, repeat=True):\n  dataset = dummy_input_fn_with_dataset(batch_size, repeat)\n  iterator = dataset_ops.make_one_shot_iterator(dataset)\n  return iterator.get_next()\n\n\ndef create_run_config(iterations_per_loop, **kwargs):\n  if 'num_shards' not in kwargs:\n    kwargs['num_shards'] = FLAGS.test_num_shards\n  return tpu_config.RunConfig(\n      master='',\n      tpu_config=tpu_config.TPUConfig(\n          iterations_per_loop=iterations_per_loop, **kwargs),\n  )\n\n\nclass TPUEstimatorExportTest(parameterized.TestCase):\n\n  def setUp(self):\n    super().setUp()\n    feature_spec = {'x': parsing_ops.FixedLenFeature([1], dtypes.float32)}\n    self._serving_input_receiver_fn = (\n        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))\n\n    feature_spec = {\n        'x':\n            array_ops.placeholder(dtype=dtypes.float32, 
shape=(2, 1), name='x'),\n    }\n    label_spec = array_ops.placeholder(\n        dtype=dtypes.float32, shape=(1, 1), name='truth')\n    self._supervised_input_receiver_fn = (\n        export_lib.build_raw_supervised_input_receiver_fn(\n            feature_spec, label_spec))\n\n  @parameterized.parameters(\n      (True, False, False),\n      (True, True, False),\n      (False, True, False),\n      (True, False, True),\n      (True, True, True),\n      (False, True, True))\n  def test_export_tpu_savedmodel_e2e(self, export_tpu_tensor, export_cpu_tensor,\n                                     use_export_mode_v2):\n    tmpdir = tempfile.mkdtemp()\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    model_fn = get_model_fn(export_tpu_tensor, export_cpu_tensor)\n    run_config = create_run_config(iterations_per_loop=4)\n    if use_export_mode_v2:\n      export_api_version = tpu_estimator.ExportSavedModelApiVersion.V2\n\n      batch_config = tpu_estimator.BatchConfig(\n          num_batch_threads=1,\n          max_batch_size=1,\n          batch_timeout_micros=100,\n          allowed_batch_sizes=[1])\n\n      def tpu_model_fn(features, labels, mode, params):\n        if mode == _PREDICT and params['use_tpu']:\n          return tpu_estimator.model_fn_inference_on_tpu(\n              model_fn, features, labels, mode, params, batch_config)\n        else:\n          return model_fn(features, labels, mode, params)\n\n      est_model_fn = tpu_model_fn\n    else:\n      export_api_version = tpu_estimator.ExportSavedModelApiVersion.V1\n      est_model_fn = model_fn\n    est = tpu_estimator.TPUEstimator(\n        model_fn=est_model_fn,\n        config=run_config,\n        train_batch_size=16,\n        export_to_tpu=True,\n        export_saved_model_api_version=export_api_version)\n    est.train(_input_fn, steps=1)\n\n    # Perform the export.\n    export_dir_base = os.path.join(\n        compat.as_bytes(tmpdir), compat.as_bytes('export'))\n    
export_dir = est.export_saved_model(export_dir_base,\n                                        self._serving_input_receiver_fn)\n\n    self._validate_export(export_dir_base, export_dir, export_tpu_tensor,\n                          export_cpu_tensor)\n\n    # Clean up.\n    gfile.DeleteRecursively(tmpdir)\n\n  def _validate_export(self, export_dir_base, export_dir, export_tpu_tensor,\n                       export_cpu_tensor):\n    # Check that all the files are in the right places.\n    self.assertTrue(gfile.Exists(export_dir_base))\n    self.assertTrue(gfile.Exists(export_dir))\n    self.assertTrue(\n        gfile.Exists(\n            os.path.join(\n                compat.as_bytes(export_dir),\n                compat.as_bytes('saved_model.pb'))))\n    self.assertTrue(\n        gfile.Exists(\n            os.path.join(\n                compat.as_bytes(export_dir), compat.as_bytes('variables'))))\n    self.assertTrue(\n        gfile.Exists(\n            os.path.join(\n                compat.as_bytes(export_dir),\n                compat.as_bytes('variables/variables.index'))))\n    self.assertTrue(\n        gfile.Exists(\n            os.path.join(\n                compat.as_bytes(export_dir),\n                compat.as_bytes('variables/variables.data-00000-of-00001'))))\n\n    def session_run():\n      example = example_pb2.Example()\n      example.features.feature['x'].float_list.value.append(1)\n\n      tensor_name_prediction = None\n      tensor_name_classes = None\n      if export_tpu_tensor:\n        key_prediction = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n        tensor_name_prediction = (\n            meta_graph.signature_def[key_prediction].\n            outputs['prediction'].name)\n        tensor_name_input = (meta_graph.signature_def[key_prediction].\n                             inputs['examples'].name)\n\n      if export_cpu_tensor:\n        key_classification = 'classification'\n        tensor_name_classes = 
(meta_graph.signature_def[key_classification].\n                               outputs['classes'].name)\n        tensor_name_input = (meta_graph.signature_def[key_classification].\n                             inputs['inputs'].name)\n\n      if export_tpu_tensor:\n        sess.run(\n            tensor_name_prediction,\n            feed_dict={tensor_name_input: [example.SerializeToString()]})\n      if export_cpu_tensor:\n        sess.run(\n            tensor_name_classes,\n            feed_dict={tensor_name_input: [example.SerializeToString()]})\n      if export_cpu_tensor and export_tpu_tensor:\n        sess.run(\n            [tensor_name_prediction, tensor_name_classes],\n            feed_dict={tensor_name_input: [example.SerializeToString()]})\n\n    # Restore, to validate that the export was well-formed.\n    with ops.Graph().as_default() as graph:\n      with session.Session(graph=graph) as sess:\n        meta_graph = loader.load(\n            sess, [tag_constants.SERVING, tag_constants.TPU], export_dir)\n        graph_ops = [x.name for x in graph.get_operations()]\n        self.assertIn('input_example_tensor', graph_ops)\n        self.assertIn('ParseExample/ParseExampleV2', graph_ops)\n        self.assertNotIn('dense/kernel/GuaranteeConst', graph_ops)\n\n        sess.run(tf.tpu.initialize_system())\n        session_run()\n\n    # Restore, to validate that the export was well-formed.\n    with ops.Graph().as_default() as graph:\n      with session.Session(graph=graph) as sess:\n        meta_graph = loader.load(sess, [tag_constants.SERVING], export_dir)\n        graph_ops = [x.name for x in graph.get_operations()]\n        self.assertIn('input_example_tensor', graph_ops)\n        self.assertIn('ParseExample/ParseExampleV2', graph_ops)\n        self.assertIn('dense/kernel', graph_ops)\n        # GuaranteeConst ops won't be present in the CPU-only graph.\n        self.assertNotIn('dense/kernel/GuaranteeConst', graph_ops)\n\n        session_run()\n\n  def 
test_export_tpu_savedmodel_export_to_cpu_false(self):\n    # Test that when `export_to_cpu` is `False`, CPU metagraph is not exported.\n    tmpdir = tempfile.mkdtemp()\n\n    model_fn = get_model_fn(export_tpu_tensor=True,\n                            export_cpu_tensor=True)\n    run_config = create_run_config(iterations_per_loop=4)\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=model_fn, config=run_config, train_batch_size=16,\n        export_to_tpu=True, export_to_cpu=False)\n    est.train(_input_fn, steps=1)\n\n    export_dir_base = os.path.join(\n        compat.as_bytes(tmpdir), compat.as_bytes('export_no_tpu'))\n    export_dir = est.export_saved_model(export_dir_base,\n                                        self._serving_input_receiver_fn)\n    saved_model = loader_impl.parse_saved_model(export_dir)\n    self.assertLen(saved_model.meta_graphs, 1)\n    tags = set(saved_model.meta_graphs[0].meta_info_def.tags)\n    self.assertEqual(tags, set([tag_constants.SERVING, tag_constants.TPU]))\n\n    # Clean up.\n    gfile.DeleteRecursively(tmpdir)\n\n  def test_export_tpu_savedmodel_export_to_tpu_false(self):\n    # Test that when `export_to_tpu` is `False`, TPU metagraph is not exported.\n    tmpdir = tempfile.mkdtemp()\n\n    model_fn = get_model_fn(export_tpu_tensor=True,\n                            export_cpu_tensor=True)\n    run_config = create_run_config(iterations_per_loop=4)\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=model_fn, config=run_config, train_batch_size=16,\n        export_to_tpu=False)\n    est.train(_input_fn, steps=1)\n\n    export_dir_base = os.path.join(\n        compat.as_bytes(tmpdir), compat.as_bytes('export_no_tpu'))\n    export_dir = est.export_saved_model(export_dir_base,\n                                        self._serving_input_receiver_fn)\n    
with ops.Graph().as_default() as graph:\n      with session.Session(graph=graph) as sess:\n        with self.assertRaisesRegex(\n            RuntimeError,\n            'MetaGraphDef associated with tags \\'serve\\', \\'tpu\\' could not be '\n            'found in SavedModel.'):\n          loader.load(\n              sess, [tag_constants.SERVING, tag_constants.TPU], export_dir)\n        loader.load(\n            sess, [tag_constants.SERVING], export_dir)\n\n    # Clean up.\n    gfile.DeleteRecursively(tmpdir)\n\n  def test_export_tpu_savedmodel_export_to_tpu_false_eval(self):\n    # Test exporting CPU evaluation graph when `export_to_tpu` is `False`.\n    tmpdir = tempfile.mkdtemp()\n    mode = model_fn_lib.ModeKeys.EVAL\n\n    model_fn = get_model_fn(export_tpu_tensor=True, export_cpu_tensor=True)\n    run_config = create_run_config(iterations_per_loop=4)\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=model_fn,\n        config=run_config,\n        train_batch_size=16,\n        export_to_tpu=False)\n    est.train(_input_fn, steps=1)\n\n    export_dir_base = os.path.join(\n        compat.as_bytes(tmpdir), compat.as_bytes('export_no_tpu_eval'))\n    export_dir = est.export_saved_model(\n        export_dir_base, self._supervised_input_receiver_fn,\n        experimental_mode=mode)\n\n    # Check that all the files are in the right places.\n    self.assertTrue(gfile.Exists(export_dir_base))\n\n    # Restore, to validate that the export was well-formed.\n    tag_set = export_lib.EXPORT_TAG_MAP[mode]\n    with ops.Graph().as_default() as graph:\n      with session.Session(graph=graph) as sess:\n        loader.load(sess, tag_set, export_dir)\n        graph_ops = [x.name for x in graph.get_operations()]\n        self.assertIn('dense/kernel', graph_ops)\n\n    # Clean up.\n    gfile.DeleteRecursively(tmpdir)\n\n  def test_export_estimator_savedmodel(self):\n    export_tpu_tensor = 
True\n    export_cpu_tensor = False\n\n    tmpdir = tempfile.mkdtemp()\n\n    def _input_fn(params):\n      del params\n      # Estimator does not pass `batch_size` to `input_fn`.\n      return dummy_input_fn(batch_size=1)\n\n    model_fn = get_model_fn(export_tpu_tensor=export_tpu_tensor,\n                            export_cpu_tensor=export_cpu_tensor,\n                            tpu_estimator_spec=False)\n    est = estimator_lib.Estimator(model_fn=model_fn)\n    est.train(_input_fn, steps=1)\n\n    # Perform the export.\n    export_dir_base = os.path.join(\n        compat.as_bytes(tmpdir), compat.as_bytes('export'))\n    export_dir = tpu_estimator.export_estimator_savedmodel(\n        est,\n        export_dir_base,\n        self._serving_input_receiver_fn)\n\n    self._validate_export(export_dir_base, export_dir, export_tpu_tensor,\n                          export_cpu_tensor)\n\n    # Clean up.\n    gfile.DeleteRecursively(tmpdir)\n\n  def test_regression_output_tensors_roundtrip(self):\n    value = array_ops.placeholder(dtypes.float32, 1, name='value')\n    regression_output = export_lib.RegressionOutput(value)\n    self.assertSequenceEqual(\n        [value],\n        tpu_estimator._export_output_to_tensors(regression_output))\n\n    value_new = array_ops.placeholder(dtypes.float32, 1, name='value_new')\n    regression_output_new = (\n        tpu_estimator._clone_export_output_with_tensors(\n            regression_output, [value_new]\n        )\n    )\n    self.assertEqual(value_new, regression_output_new.value)\n\n  def test_predict_output_tensors_roundtrip(self):\n    value1 = array_ops.placeholder(dtypes.float32, 1, name='value1')\n    value2 = array_ops.placeholder(dtypes.float32, 1, name='value2')\n    predict_output = export_lib.PredictOutput({\n        'value1': value1,\n        'value2': value2\n    })\n    export_output_tensors = tpu_estimator._export_output_to_tensors(\n        predict_output)\n    self.assertSameElements([value1, value2], 
export_output_tensors)\n    self.assertLen(export_output_tensors, 2)\n\n    tensors_new = [\n        array_ops.identity(t, name=t.name.split(':')[0] + '_new')\n        for t in export_output_tensors\n    ]\n    predict_output_new = tpu_estimator._clone_export_output_with_tensors(\n        predict_output, tensors_new)\n    outputs = predict_output_new.outputs\n    self.assertLen(outputs, 2)\n    self.assertEqual(outputs['value1'].name, 'value1_new:0')\n    self.assertEqual(outputs['value2'].name, 'value2_new:0')\n\n  def test_classification_output_tensors_roundtrip_classes_only(self):\n    classes = array_ops.placeholder(dtypes.string, 1, name='classes')\n    classification_output = export_lib.ClassificationOutput(\n        classes=classes)\n\n    classification_output_tensors = (tpu_estimator.\n                                     _export_output_to_tensors(\n                                         classification_output))\n    self.assertEqual(classification_output_tensors, [None, classes])\n\n    classes_new = array_ops.placeholder(dtypes.string, 1, name='classes_new')\n    classification_output_new = (tpu_estimator.\n                                 _clone_export_output_with_tensors(\n                                     classification_output,\n                                     [None, classes_new]))\n    self.assertEqual(classification_output_new.classes, classes_new)\n\n  def test_classification_output_tensors_roundtrip_scores_only(self):\n    scores = array_ops.placeholder(dtypes.float32, 1, name='scores')\n    classification_output = export_lib.ClassificationOutput(\n        scores=scores)\n\n    classification_output_tensors = (tpu_estimator.\n                                     _export_output_to_tensors(\n                                         classification_output))\n    self.assertEqual(classification_output_tensors, [scores, None])\n\n    scores_new = array_ops.placeholder(dtypes.float32, 1, name='scores_new')\n    classification_output_new = 
(tpu_estimator.\n                                 _clone_export_output_with_tensors(\n                                     classification_output, [scores_new, None]))\n    self.assertEqual(classification_output_new.scores, scores_new)\n\n  def test_classification_output_tensors_roundtrip_classify_both(self):\n    classes = array_ops.placeholder(dtypes.string, 1, name='classes')\n    scores = array_ops.placeholder(dtypes.float32, 1, name='scores')\n    classification_output = export_lib.ClassificationOutput(\n        scores, classes)\n\n    classification_output_tensors = (tpu_estimator.\n                                     _export_output_to_tensors(\n                                         classification_output))\n    self.assertSequenceEqual(classification_output_tensors, [scores, classes])\n\n    classes_new = array_ops.placeholder(dtypes.string, 1, name='classes_new')\n    scores_new = array_ops.placeholder(dtypes.float32, 1, name='scores_new')\n    classification_output_new = (tpu_estimator.\n                                 _clone_export_output_with_tensors(\n                                     classification_output,\n                                     [scores_new, classes_new]))\n    self.assertEqual(classification_output_new.classes, classes_new)\n    self.assertEqual(classification_output_new.scores, scores_new)\n\n\ndef get_model_fn_v2():\n\n  def model_fn(features, labels, mode, params):\n    loss = None\n    train_op = None\n    export_outputs = None\n\n    # This could be some pre-processing on CPU like calls to input layer with\n    # embedding columns.\n    x2 = features['x'] * 2\n\n    def computation(input_tensor):\n      return tf_keras_v1.__internal__.legacy.layers.dense(\n          input_tensor, 1, kernel_initializer=init_ops.zeros_initializer())\n\n    if mode != _PREDICT:\n      predictions = computation(x2)\n      loss = losses.mean_squared_error(labels, predictions)\n      optimizer = tf.tpu.CrossShardOptimizer(\n          
training.GradientDescentOptimizer(learning_rate=0.5))\n      train_op = optimizer.minimize(loss, training.get_global_step())\n    else:\n      inputs = [x2]\n      if params['use_tpu']:\n        predictions = array_ops.identity(\n            tpu_estimator.inference_on_tpu(\n                computation, inputs, num_batch_threads=1, max_batch_size=2,\n                batch_timeout_micros=100),\n            name='predictions')\n      else:\n        predictions = array_ops.identity(\n            computation(*inputs), name='predictions')\n      key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n      export_outputs = {\n          key: export_lib.PredictOutput({'prediction': predictions})\n      }\n\n      classes = string_ops.as_string(predictions, name='classes')\n      classification_output = export_lib.ClassificationOutput(classes=classes)\n      export_outputs['classification'] = classification_output\n\n    return tpu_estimator.TPUEstimatorSpec(\n        mode,\n        loss=loss,\n        train_op=train_op,\n        predictions={'predictions': predictions},\n        export_outputs=export_outputs)\n\n  return model_fn\n\n\nclass TPUEstimatorExportV2Test(parameterized.TestCase):\n\n  def setUp(self):\n    super().setUp()\n    feature_spec = {'x': parsing_ops.FixedLenFeature([1], dtypes.float32)}\n    self._serving_input_receiver_fn = (\n        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))\n\n  def test_export_tpu_savedmodel_e2e(self):\n    tmpdir = tempfile.mkdtemp()\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    model_fn = get_model_fn_v2()\n    run_config = create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=model_fn,\n        config=run_config,\n        train_batch_size=16,\n        export_to_tpu=True,\n        export_saved_model_api_version=tpu_estimator.ExportSavedModelApiVersion\n        .V2)\n    est.train(_input_fn, steps=1)\n\n    # Perform 
the export.\n    export_dir_base = os.path.join(\n        compat.as_bytes(tmpdir), compat.as_bytes('export'))\n    export_dir = est.export_saved_model(export_dir_base,\n                                        self._serving_input_receiver_fn)\n\n    self._validate_export(export_dir_base, export_dir)\n\n    # Clean up.\n    gfile.DeleteRecursively(tmpdir)\n\n  def _validate_export(self, export_dir_base, export_dir):\n    # Check that all the files are in the right places.\n    self.assertTrue(gfile.Exists(export_dir_base))\n    self.assertTrue(gfile.Exists(export_dir))\n    self.assertTrue(\n        gfile.Exists(\n            os.path.join(\n                compat.as_bytes(export_dir),\n                compat.as_bytes('saved_model.pb'))))\n    self.assertTrue(\n        gfile.Exists(\n            os.path.join(\n                compat.as_bytes(export_dir), compat.as_bytes('variables'))))\n    self.assertTrue(\n        gfile.Exists(\n            os.path.join(\n                compat.as_bytes(export_dir),\n                compat.as_bytes('variables/variables.index'))))\n    self.assertTrue(\n        gfile.Exists(\n            os.path.join(\n                compat.as_bytes(export_dir),\n                compat.as_bytes('variables/variables.data-00000-of-00001'))))\n\n    def session_run():\n      example = example_pb2.Example()\n      example.features.feature['x'].float_list.value.append(1)\n\n      tensor_name_prediction = None\n      tensor_name_classes = None\n      key_prediction = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n      tensor_name_prediction = (\n          meta_graph.signature_def[key_prediction].outputs['prediction'].name)\n      key_classification = 'classification'\n      tensor_name_classes = (\n          meta_graph.signature_def[key_classification].outputs['classes'].name)\n\n      sess.run(\n          tensor_name_prediction,\n          feed_dict={'input_example_tensor:0': [example.SerializeToString()]})\n      sess.run(\n          
tensor_name_classes,\n          feed_dict={'input_example_tensor:0': [example.SerializeToString()]})\n      sess.run(\n          [tensor_name_prediction, tensor_name_classes],\n          feed_dict={'input_example_tensor:0': [example.SerializeToString()]})\n\n    # Restore, to validate that the export was well-formed.\n    with ops.Graph().as_default() as graph:\n      with session.Session(graph=graph) as sess:\n        meta_graph = loader.load(sess,\n                                 [tag_constants.SERVING, tag_constants.TPU],\n                                 export_dir)\n        graph_ops = [x.name for x in graph.get_operations()]\n        self.assertIn('input_example_tensor', graph_ops)\n        self.assertIn('ParseExample/ParseExampleV2', graph_ops)\n        self.assertNotIn('dense/kernel/GuaranteeConst', graph_ops)\n        self.assertIn('batch/BatchFunction', graph_ops)\n\n        sess.run(tf.tpu.initialize_system())\n        session_run()\n\n    # Restore, to validate that the export was well-formed.\n    with ops.Graph().as_default() as graph:\n      with session.Session(graph=graph) as sess:\n        meta_graph = loader.load(sess, [tag_constants.SERVING], export_dir)\n        graph_ops = [x.name for x in graph.get_operations()]\n        self.assertIn('input_example_tensor', graph_ops)\n        self.assertIn('ParseExample/ParseExampleV2', graph_ops)\n        self.assertIn('dense/kernel', graph_ops)\n        # GuaranteeConst ops won't be present in the CPU-only graph.\n        self.assertNotIn('dense/kernel/GuaranteeConst', graph_ops)\n\n        session_run()\n\n  def test_export_tpu_savedmodel_export_to_tpu_false(self):\n    # Test that when `export_to_tpu` is `False`, TPU metagraph is not exported.\n    tmpdir = tempfile.mkdtemp()\n\n    model_fn = get_model_fn_v2()\n    run_config = create_run_config(iterations_per_loop=4)\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    est = tpu_estimator.TPUEstimator(\n        
model_fn=model_fn,\n        config=run_config,\n        train_batch_size=16,\n        export_to_tpu=False,\n        export_saved_model_api_version=tpu_estimator.ExportSavedModelApiVersion\n        .V2)\n    est.train(_input_fn, steps=1)\n\n    export_dir_base = os.path.join(\n        compat.as_bytes(tmpdir), compat.as_bytes('export_no_tpu'))\n    export_dir = est.export_saved_model(export_dir_base,\n                                        self._serving_input_receiver_fn)\n    with ops.Graph().as_default() as graph:\n      with session.Session(graph=graph) as sess:\n        with self.assertRaisesRegex(\n            RuntimeError,\n            'MetaGraphDef associated with tags \\'serve\\', \\'tpu\\' could not be '\n            'found in SavedModel.'):\n          loader.load(sess, [tag_constants.SERVING, tag_constants.TPU],\n                      export_dir)\n        loader.load(sess, [tag_constants.SERVING], export_dir)\n\n    # Clean up.\n    gfile.DeleteRecursively(tmpdir)\n\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/tpu_estimator_gradients_test.py",
    "content": "# Copyright 2021 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests to check gradients of TPUEstimator + TPU Embeddings.\"\"\"\n\nimport math\nimport tempfile\nfrom absl import flags\nimport numpy as np\nimport tensorflow.compat.v1 as tf\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v1\nfrom tensorflow_estimator.python.estimator.tpu import tpu_config\nfrom tensorflow_estimator.python.estimator.tpu import tpu_estimator\n\nflags.DEFINE_integer('test_num_shards', 2, 'number of replicas to test')\n\nFLAGS = flags.FLAGS\n\nLEARNING_RATE = 0.12\nHIDDEN_LAYER_SIZE = 20\nKERNEL_INIT_VALUE = 0.1\nBIAS_INIT_VALUE = 0.2\nADADGRAD_INIT_VALUE = 0.1\nBUCKET_SIZE = 8\nEMBEDDING_DIM = 3\nKEY_NAME = 'x'\nGRAD_MULTIPLIER = 1000.\n\nBIAS_VAR = 'dense/bias:0'\nCPU_EMBEDDING_VAR = 'dense_features/x_embedding/embedding_weights:0'\nCPU_EMBEDDING_ACCUM_VAR = 'dense_features/x_embedding/embedding_weights/Adagrad:0'\nTPU_EMBEDDING_VAR = 'dense_features/x_embedding/embedding_weights/part_0:0'\nTPU_EMBEDDING_ACCUM_VAR = 'dense_features/x_embedding/embedding_weights/Adagrad/part_0:0'\n\n# This test must be running with \"--xla_jf_conv_full_precision=true\",\nDEFAULT_TOL = 1e-6\n\n\ndef create_model_fn(feature_columns, optimizer_type='adagrad'):\n\n  def model_fn(features, labels, mode, params):\n    del params\n\n    
dense_features = tf_keras_v1.layers.DenseFeatures(feature_columns)\n    input_layer = dense_features(features)\n    hidden_layer = tf_keras_v1.__internal__.legacy.layers.dense(\n        input_layer,\n        HIDDEN_LAYER_SIZE,\n        kernel_initializer=tf.constant_initializer(KERNEL_INIT_VALUE),\n        bias_initializer=tf.constant_initializer(BIAS_INIT_VALUE))\n\n    last_layer = tf.reduce_sum(hidden_layer, axis=1)\n\n    logits = tf.reshape(last_layer, [-1])\n    labels = tf.reshape(labels, [-1])\n    losses = tf.square(labels - logits)\n\n    # Use reduce_mean to match the CrossShardOptimizer reduction.\n    loss = tf.reduce_mean(losses)\n    if optimizer_type == 'adagrad':\n      optimizer = tf.train.AdagradOptimizer(\n          LEARNING_RATE, initial_accumulator_value=ADADGRAD_INIT_VALUE)\n    elif optimizer_type == 'sgd':\n      optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE)\n    else:\n      raise ValueError('{} is not supported.'.format(optimizer_type))\n    # Default reduction=tf.losses.Reduction.MEAN\n    optimizer = tf.tpu.CrossShardOptimizer(optimizer)\n\n    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())\n    return tpu_estimator.TPUEstimatorSpec(mode=mode, loss=loss, train_op=train_op)\n\n  return model_fn\n\n\ndef get_estimator(use_tpu,\n                  output_dir,\n                  feature_columns,\n                  batch_size,\n                  optimizer_type='adagrad',\n                  grad_multiplier_fn=None):\n  run_config = tpu_config.RunConfig(\n      master='',\n      model_dir=output_dir,\n      session_config=tf.ConfigProto(\n          allow_soft_placement=True, log_device_placement=False),\n      tpu_config=tpu_config.TPUConfig(\n          iterations_per_loop=1,\n          num_shards=FLAGS.test_num_shards,\n          per_host_input_for_training=(\n              tpu_config.InputPipelineConfig.PER_HOST_V2)),\n      save_checkpoints_steps=1)\n\n  if optimizer_type == 'adagrad':\n    
optimization_parameters = tpu_estimator.AdagradParameters(\n        LEARNING_RATE,\n        ADADGRAD_INIT_VALUE,\n        use_gradient_accumulation=False)\n  elif optimizer_type == 'sgd':\n    optimization_parameters = tpu_estimator.StochasticGradientDescentParameters(\n        LEARNING_RATE)\n\n  estimator = tpu_estimator.TPUEstimator(\n      model_fn=create_model_fn(feature_columns, optimizer_type),\n      use_tpu=use_tpu,\n      config=run_config,\n      train_batch_size=batch_size,\n      eval_batch_size=batch_size,\n      embedding_config_spec=tpu_estimator.EmbeddingConfigSpec(\n          feature_columns=feature_columns,\n          optimization_parameters=optimization_parameters,\n          experimental_gradient_multiplier_fn=grad_multiplier_fn))\n  return estimator\n\n\ndef get_feature_columns():\n  initializer = tf.zeros_initializer()\n\n  column = tf.feature_column.categorical_column_with_identity(\n      key=KEY_NAME, num_buckets=BUCKET_SIZE)\n  embedding_fc = tf.tpu.experimental.embedding_column(\n      column,\n      dimension=EMBEDDING_DIM,\n      combiner='mean',\n      initializer=initializer)\n\n  all_fc = [embedding_fc]\n  return all_fc\n\n\nclass _EmbeddingVariableHook(tf.train.SessionRunHook):\n  \"\"\"A hook to record the embedding variable.\"\"\"\n\n  def __init__(self, use_tpu, include_slot_vars=True):\n    self._use_tpu = use_tpu\n    self._include_slot_vars = include_slot_vars\n\n  def _set_bias_var(self):\n    self._bias_var = [v for v in tf.trainable_variables() if v.name == BIAS_VAR]\n\n  def begin(self):\n    search_var = TPU_EMBEDDING_VAR if self._use_tpu else CPU_EMBEDDING_VAR\n    self._var = [v for v in tf.global_variables() if v.name == search_var][0]\n    if self._include_slot_vars:\n      search_accum_var = TPU_EMBEDDING_ACCUM_VAR if self._use_tpu else CPU_EMBEDDING_ACCUM_VAR\n      self._slot_var = [\n          v for v in tf.global_variables() if v.name == search_accum_var\n      ][0]\n    self._set_bias_var()\n\n    
self.bias_values = []\n    self.var_values = []\n    self.slot_var_values = []\n\n  def after_create_session(self, session, coord):\n    del coord\n    self.bias_values.append(session.run(self._bias_var))\n    self.var_values.append(session.run(self._var))\n    if self._include_slot_vars:\n      self.slot_var_values.append(session.run(self._slot_var))\n\n  def after_run(self, run_context, run_values):\n    self.bias_values.append(run_context.session.run(self._bias_var))\n    self.var_values.append(run_context.session.run(self._var))\n    if self._include_slot_vars:\n      self.slot_var_values.append(run_context.session.run(self._slot_var))\n\n\ndef get_activation_gradients(label):\n  \"\"\"Gets the sample gradient w.r.t activation according to the model_fn.\"\"\"\n  # The sample loss is (label - logits)**2, where\n  #     logits = \\sum_j^HIDDEN_LAYER_SIZE (\n  #         \\sum_i^EMBEDDING_DIM w_i * kernel + bias)\n  #\n  # Note kernel and bias are both constant initializer in this test.\n  #\n  # So, gradients of loss w.r.t w_i is\n  #    grads = 2 * (label - logits) gradients( logits w.r.t. 
w_i)\n  #          = 2 * (label - logits) (-1 * HIDDEN_LAYER_SIZE * kernel)\n  #\n  # Given the weights are zero initializer,\n  #    grads = - 2 HIDDEN_LAYER_SIZE * kernel (label - HIDDEN_LAYER_SIZE * bias)\n\n  return -2 * HIDDEN_LAYER_SIZE * KERNEL_INIT_VALUE * (\n      label - HIDDEN_LAYER_SIZE * BIAS_INIT_VALUE)\n\n\ndef get_embedding_update(gradient, previous_accum_inc=0.0):\n  \"\"\"Gets the embedding update according to Adagrad.\n\n  Args:\n    gradient: the embedding gradient.\n    previous_accum_inc: The previous total accumulator increment (in addition to\n        the initialize value)\n\n  Returns:\n    the value to apply gradient.\n  \"\"\"\n  return -LEARNING_RATE * (\n      gradient /\n      math.sqrt(ADADGRAD_INIT_VALUE + previous_accum_inc + gradient**2))\n\n\ndef dense_to_sparse(dense_tensor, out_type, ignore_value=-1):\n  indices = tf.where(\n      tf.not_equal(dense_tensor,\n                   tf.constant(ignore_value, dense_tensor.dtype)))\n  values = tf.gather_nd(dense_tensor, indices)\n  shape = tf.shape(dense_tensor, out_type=out_type)\n  return tf.SparseTensor(indices, values, shape)\n\n\nclass TPUEstimatorGradientsSimpleTest(tf.test.TestCase):\n  \"\"\"Test gradients for different Ids in global batch.\n\n  In all examples examined by this test, in one global batch, each embedding ID\n  appears only once. 
So, we can expect the embedding variable and accumulate\n  variable will be same after one CPU training and TPU training.\n\n  For more complicated example, each ID can appear multiple times in one core\n  mini-batch and across multiple cores, see\n  TPUEstimatorGradientsWithIdCollisionTest.\n  \"\"\"\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def _input_fn(self, params):\n    # This input_fn returns a tuple of sparse tensor and a dense tensor in\n    # sequence.\n    # sample 0:  sparse tensor value [0]  dense (target) [1]\n    # sample 1:  sparse tensor value [1]  dense (target) [2]\n    # sample 2:  sparse tensor value [2]  dense (target) [3]\n    # ...\n    batch_size = params['batch_size']\n\n    ds = tf.data.Dataset.range(8)\n\n    def _map_fn(index):\n      index = tf.reshape(index, [1])\n      dense_tensor = tf.cast(index + 1, tf.float32)\n      return ({KEY_NAME: dense_to_sparse(index, tf.int64)}, dense_tensor)\n\n    ds = ds.map(_map_fn)\n    ds = ds.batch(batch_size, drop_remainder=True)\n    return ds\n\n  def test_input_fn(self):\n    ds = self._input_fn({'batch_size': 1})\n    gn = ds.make_one_shot_iterator().get_next()\n\n    with tf.Session() as sess:\n      for i in range(8):\n        features, dense_tensor = sess.run(gn)\n        sparse_tensor = features[KEY_NAME]\n\n        self.assertAllEqual([[0, 0]], sparse_tensor.indices)\n        self.assertAllEqual([i], sparse_tensor.values)\n        self.assertAllEqual([[i + 1]], dense_tensor)\n\n    tf.reset_default_graph()\n\n    ds = self._input_fn({'batch_size': 2})\n    gn = ds.make_one_shot_iterator().get_next()\n\n    with tf.Session() as sess:\n      for i in range(4):\n        features, dense_tensor = sess.run(gn)\n        sparse_tensor = features[KEY_NAME]\n        self.assertAllEqual([[0, 0], [1, 0]], sparse_tensor.indices)\n        self.assertAllEqual([i * 2, i * 2 + 1], sparse_tensor.values)\n        self.assertAllEqual([[i * 2 + 1], [i * 2 + 2]], dense_tensor)\n\n  
def assert_embedding_variables(self,\n                                 gradients_for_embedding,\n                                 hand_calculated_embedding_values,\n                                 values_in_hook,\n                                 tol=DEFAULT_TOL):\n    \"\"\"Assert the embedding variables after training one step.\"\"\"\n\n    expected_embedding_var_values = []\n    # Before training, all zeros (zeros_initializer)\n    expected_embedding_var_values.append(np.zeros((BUCKET_SIZE, EMBEDDING_DIM)))\n\n    after_training_var_values = np.zeros((BUCKET_SIZE, EMBEDDING_DIM))\n    embedding_row_value_after_one_step = [\n        get_embedding_update(g) for g in gradients_for_embedding\n    ]\n\n    for i in range(len(embedding_row_value_after_one_step)):\n      after_training_var_values[i][:] = embedding_row_value_after_one_step[i]\n    expected_embedding_var_values.append(after_training_var_values)\n\n    # Check against hand calculated value.\n    self.assertAllClose(hand_calculated_embedding_values,\n                        embedding_row_value_after_one_step)\n    # Check against the value recorded during training.\n    self.assertAllClose(\n        expected_embedding_var_values, values_in_hook, atol=tol, rtol=tol)\n\n  def assert_embedding_slot_variables(self, gradients_for_embedding,\n                                      hand_calculated_embedding_slot_values,\n                                      values_in_hook, tol):\n    \"\"\"Assert the embedding slot variables after training one step.\"\"\"\n\n    expected_embedding_slot_var_values = []\n    # Before training, all same (ADADGRAD_INIT_VALUE)\n    expected_embedding_slot_var_values.append(\n        np.ones((BUCKET_SIZE, EMBEDDING_DIM)) * ADADGRAD_INIT_VALUE)\n\n    after_training_slot_var_values = np.zeros((BUCKET_SIZE, EMBEDDING_DIM))\n    accumulator_sum = [\n        ADADGRAD_INIT_VALUE + g * g for g in gradients_for_embedding\n    ]\n    for i in range(len(accumulator_sum)):\n      
after_training_slot_var_values[i][:] = accumulator_sum[i]\n\n    expected_embedding_slot_var_values.append(after_training_slot_var_values)\n\n    # Check against hand calculated value.\n    self.assertAllClose(hand_calculated_embedding_slot_values, accumulator_sum)\n    # Check against the value recorded during training.\n    self.assertAllClose(\n        expected_embedding_slot_var_values, values_in_hook, atol=tol, rtol=tol)\n\n  def test_one_sample_per_core(self):\n    use_tpu = True\n    per_core_batch_size = 1\n    num_shards = FLAGS.test_num_shards\n    batch_size = num_shards * per_core_batch_size\n\n    hook = _EmbeddingVariableHook(use_tpu=use_tpu)\n\n    estimator = get_estimator(use_tpu, self._model_dir, get_feature_columns(),\n                              batch_size)\n    estimator.train(self._input_fn, steps=1, hooks=[hook])\n\n    # After training one step, the core 0 gets one sample with ID 0, and core 1\n    # gets one sample with ID 1. So, all other IDs' embedding vars remain as\n    # zeros.\n    gradients_for_embedding = [\n        get_activation_gradients(label=1),\n        get_activation_gradients(label=2)\n    ]\n    gradients_for_embedding += [0] * (BUCKET_SIZE - batch_size)\n    # Scale the gradients by 1/num_shards as CrossShardOptimizer scales the\n    # loss for MEAN reduction.\n    gradients_for_embedding = np.array(gradients_for_embedding) / num_shards\n\n    hand_calculated_embedding_values = [0] * BUCKET_SIZE\n    # Gradients are 6.0 and 4.0. 
the embedding value should\n    #    - LEARNING_RATE* x / math.sqrt(ADADGRAD_INIT_VALUE + x*x)\n    hand_calculated_embedding_values[:2] = [\n        -0.1198336797537491, -0.11962674870701442\n    ]\n\n    self.assert_embedding_variables(\n        gradients_for_embedding=gradients_for_embedding,\n        hand_calculated_embedding_values=hand_calculated_embedding_values,\n        values_in_hook=hook.var_values,\n        tol=DEFAULT_TOL)\n\n    hand_calculated_embedding_slot_values = [ADADGRAD_INIT_VALUE] * BUCKET_SIZE\n    hand_calculated_embedding_slot_values[0] += 6.0**2\n    hand_calculated_embedding_slot_values[1] += 4.0**2\n\n    self.assert_embedding_slot_variables(\n        gradients_for_embedding=gradients_for_embedding,\n        hand_calculated_embedding_slot_values=(\n            hand_calculated_embedding_slot_values),\n        values_in_hook=hook.slot_var_values,\n        tol=DEFAULT_TOL)\n\n  def test_one_sample_per_core_tpu_vs_cpu(self):\n    use_tpu = True\n    per_core_batch_size = 1\n    num_shards = FLAGS.test_num_shards\n    batch_size = num_shards * per_core_batch_size\n\n    # TPU\n    tpu_hook = _EmbeddingVariableHook(use_tpu=use_tpu)\n    estimator = get_estimator(use_tpu, self._model_dir + '_tpu',\n                              get_feature_columns(), batch_size)\n    estimator.train(self._input_fn, steps=1, hooks=[tpu_hook])\n\n    # CPU\n    use_tpu = False\n    cpu_hook = _EmbeddingVariableHook(use_tpu=use_tpu)\n\n    cpu_estimator = get_estimator(use_tpu, self._model_dir + '_cpu',\n                                  get_feature_columns(), batch_size)\n    cpu_estimator.train(self._input_fn, steps=1, hooks=[cpu_hook])\n\n    tol = DEFAULT_TOL\n    self.assertAllClose(\n        tpu_hook.var_values, cpu_hook.var_values, atol=tol, rtol=tol)\n    self.assertAllClose(\n        tpu_hook.slot_var_values, cpu_hook.slot_var_values, atol=tol, rtol=tol)\n\n    # Also check dense.\n    self.assertAllClose(\n        tpu_hook.bias_values, 
cpu_hook.bias_values, atol=tol, rtol=tol)\n\n  def test_multi_samples_per_core(self):\n    use_tpu = True\n    per_core_batch_size = 2\n    num_shards = FLAGS.test_num_shards\n    batch_size = num_shards * per_core_batch_size\n\n    hook = _EmbeddingVariableHook(use_tpu=use_tpu)\n\n    estimator = get_estimator(use_tpu, self._model_dir, get_feature_columns(),\n                              batch_size)\n    estimator.train(self._input_fn, steps=1, hooks=[hook])\n\n    # After training one step, the core 0 gets two samples with ID 0 and 1. For\n    # core 1 gets two samples with ID 2 and 3. So, all other IDs' embedding vars\n    # remain as zeros.\n    gradients_for_embedding = [\n        get_activation_gradients(label=1),\n        get_activation_gradients(label=2),\n        get_activation_gradients(label=3),\n        get_activation_gradients(label=4)\n    ]\n    gradients_for_embedding += [0] * (BUCKET_SIZE - batch_size)\n    gradients_for_embedding = np.array(gradients_for_embedding)\n    # Scale the gradients by 1/ per_core_batch_size, as for each core the loss\n    # is mean loss.\n    gradients_for_embedding /= per_core_batch_size\n    # Further scale the gradients by 1/num_shards as CrossShardOptimizer scales\n    # the loss for MEAN reduction.\n    gradients_for_embedding /= num_shards\n\n    # Gradients are 3.0, 2.0, 1.0, and 0.0. 
the embedding value should\n    #    - LEARNING_RATE* x / math.sqrt(ADADGRAD_INIT_VALUE + x*x)\n    hand_calculated_embedding_values = [0] * BUCKET_SIZE\n    hand_calculated_embedding_values[:2] = [-0.119338837, -0.118527551]\n    hand_calculated_embedding_values[2:4] = [-0.1144155107094, 0]\n\n    self.assert_embedding_variables(\n        gradients_for_embedding=gradients_for_embedding,\n        hand_calculated_embedding_values=hand_calculated_embedding_values,\n        values_in_hook=hook.var_values,\n        tol=DEFAULT_TOL)\n\n    hand_calculated_embedding_slot_values = [ADADGRAD_INIT_VALUE] * BUCKET_SIZE\n    hand_calculated_embedding_slot_values[0] += 3.0**2\n    hand_calculated_embedding_slot_values[1] += 2.0**2\n    hand_calculated_embedding_slot_values[2] += 1.0**2\n\n    self.assert_embedding_slot_variables(\n        gradients_for_embedding=gradients_for_embedding,\n        hand_calculated_embedding_slot_values=(\n            hand_calculated_embedding_slot_values),\n        values_in_hook=hook.slot_var_values,\n        tol=DEFAULT_TOL)\n\n  def test_multi_samples_per_core_tpu_vs_cpu(self):\n    use_tpu = True\n    per_core_batch_size = 2\n    num_shards = FLAGS.test_num_shards\n    batch_size = num_shards * per_core_batch_size\n\n    # TPU\n    tpu_hook = _EmbeddingVariableHook(use_tpu=use_tpu)\n    estimator = get_estimator(use_tpu, self._model_dir + '_tpu',\n                              get_feature_columns(), batch_size)\n    estimator.train(self._input_fn, steps=1, hooks=[tpu_hook])\n\n    # CPU\n    use_tpu = False\n    cpu_hook = _EmbeddingVariableHook(use_tpu=use_tpu)\n    cpu_estimator = get_estimator(use_tpu, self._model_dir + '_cpu',\n                                  get_feature_columns(), batch_size)\n    cpu_estimator.train(self._input_fn, steps=1, hooks=[cpu_hook])\n\n    tol = DEFAULT_TOL\n    self.assertAllClose(\n        tpu_hook.var_values, cpu_hook.var_values, atol=tol, rtol=tol)\n    self.assertAllClose(\n        
tpu_hook.slot_var_values,\n        cpu_hook.slot_var_values,\n        atol=tol,\n        rtol=tol)\n\n    # Also check dense.\n    self.assertAllClose(\n        tpu_hook.bias_values, cpu_hook.bias_values, atol=tol, rtol=tol)\n\n\nclass TPUEstimatorGradientsWithIdCollisionTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def _input_fn(self, params):\n    # This input_fn is expected to be called twice each having a batch_size 2.\n    # The first output will be\n    #   label = [1, 2]\n    #   sparse inputs: SarseTensorValue(\n    #       indices=array([[0, 0], [0, 1],\n    #                      [1, 0], [1, 1]]),\n    #       values=array([0, 1,\n    #                     1, 2]),\n    #       dense_shape=array([2, 2]))\n    #\n    # The second output will be\n    #   label = [3, 4]\n    #   sparse inputs: SarseTensorValue(\n    #       indices=array([[0, 0], [0, 1],\n    #                      [1, 0], [1, 1]]),\n    #       values=array([1, 2,\n    #                     2, 3]),\n    #       dense_shape=array([2, 2]))\n    #\n    # So, each sample has two ids. Each core gets two samples, which share some\n    # ids. 
And different cores share ids also.\n    batch_size = params['batch_size']\n    self.assertTrue(batch_size == 2 or batch_size == 4)\n\n    ds = tf.data.Dataset.range(8)\n\n    def _map_fn(index):\n      x = tf.floordiv(index, 2)\n      y = tf.floormod(index, 2)\n\n      label = tf.cast(index + 1, tf.float32)\n      label = tf.reshape(label, [1])\n\n      target_dense = tf.stack([x + y, x + y + 1])\n      return ({KEY_NAME: dense_to_sparse(target_dense, tf.int64)}, label)\n\n    ds = ds.map(_map_fn)\n    ds = ds.batch(batch_size, drop_remainder=True)\n    return ds\n\n  def test_input_fn(self):\n    ds = self._input_fn({'batch_size': 2})\n    gn = ds.make_one_shot_iterator().get_next()\n\n    with tf.Session() as sess:\n      # First call\n      features, label = sess.run(gn)\n      sparse_tensor = features[KEY_NAME]\n      self.assertAllEqual([[0, 0], [0, 1], [1, 0], [1, 1]],\n                          sparse_tensor.indices)\n      self.assertAllEqual([\n          0,\n          1,\n          1,\n          2,\n      ], sparse_tensor.values)\n      self.assertAllEqual([[1], [2]], label)\n\n      # second call\n      features, label = sess.run(gn)\n      sparse_tensor = features[KEY_NAME]\n      self.assertAllEqual([[0, 0], [0, 1], [1, 0], [1, 1]],\n                          sparse_tensor.indices)\n      self.assertAllEqual([\n          1,\n          2,\n          2,\n          3,\n      ], sparse_tensor.values)\n      self.assertAllEqual([[3], [4]], label)\n\n  def test_adagrad_opt_embedding_variables_on_tpu(self):\n    use_tpu = True\n    per_core_batch_size = 2\n    num_shards = FLAGS.test_num_shards\n    batch_size = num_shards * per_core_batch_size\n\n    hook = _EmbeddingVariableHook(use_tpu=use_tpu)\n\n    estimator = get_estimator(use_tpu, self._model_dir, get_feature_columns(),\n                              batch_size)\n    estimator.train(self._input_fn, steps=1, hooks=[hook])\n\n    final_step = 1\n    tol = DEFAULT_TOL\n\n    # In this parcticular 
example, the gradient w.r.t. each activation is not\n    # gradient w.r.t. embedding due to the combiner.\n    unscaled_gradient_for_activation = [\n        get_activation_gradients(label=1),\n        get_activation_gradients(label=2),\n        get_activation_gradients(label=3),\n        get_activation_gradients(label=4),\n    ]\n    self.assertAllEqual([12., 8., 4.0, 0.0], unscaled_gradient_for_activation)\n\n    # Due to reduce_mean and 1/num_shards scaling, the embeddings gradients\n    # are 3.0, 2.0, 1.0, 0.0 as num of samples per core is 2 and\n    # num_shards (number of cores) is 2.\n\n    # Now calcuates the gradients for embedding vars and accumulator for each\n    # var.\n    #\n    # Note the IDs for each core are\n\n    # Core 0 sample 0  IDs: [0  1]\n    # Core 0 sample 1  IDs: [1  2]\n\n    # Core 1 sample 0  IDs: [1  2]\n    # Core 1 sample 1  IDs: [2  3]\n\n    # For embedding ID 0, it appears only in the first sample of the first core.\n    # So, its gradient is 3.0 / 2, where 1/2 is due to the mean combiner.\n    gradient_for_id_0 = 1.5\n    accumuator_for_id_0 = gradient_for_id_0**2 + ADADGRAD_INIT_VALUE\n\n    self.assertAllClose(\n        accumuator_for_id_0,\n        hook.slot_var_values[final_step][0][0],\n        atol=tol, rtol=tol)\n\n    # embedding_update = - LR * g / (init_accm + g**2)\n    gradient_update_for_id_0 = get_embedding_update(gradient_for_id_0)\n    self.assertAllClose(\n        gradient_update_for_id_0,\n        hook.var_values[final_step][0][0],\n        rtol=tol, atol=tol)\n\n    # Similarly, for embedding ID 3, it appears only in the second sample of the\n    # second core.  
So, its gradient is 0.0 / 2, where 1/2 is due to the mean\n    # combiner.\n    gradient_for_id_3 = 0\n    accumuator_for_id_3 = gradient_for_id_3**2 + ADADGRAD_INIT_VALUE\n\n    self.assertAllClose(\n        accumuator_for_id_3,\n        hook.slot_var_values[final_step][3][0],\n        atol=tol, rtol=tol)\n\n    # embedding_update = - LR * g / (init_accm + g**2)\n    gradient_update_for_id_3 = get_embedding_update(gradient_for_id_3)\n    self.assertAllClose(\n        gradient_update_for_id_3,\n        hook.var_values[final_step][3][0],\n        rtol=tol, atol=tol)\n\n    # For embedding ID 2, it appears in\n    #  - second sample of first core\n    #  - first sample of second core\n    #  - second sample of second core\n    #\n    # Note that the gradients of the second activation of the second core is 0.\n    # So, equivalent, it is same as\n    #\n    #  - second sample of first core -> gradient = 0.5 * 2.0 = 1.0\n    #  - first sample of second core -> gradient = 0.5 * 1.0 = 0.5\n    gradient_for_id_2_in_core_0 = 1.0\n    gradient_for_id_2_in_core_1 = 0.5\n    accumuator_for_id_2 = (\n        ADADGRAD_INIT_VALUE + gradient_for_id_2_in_core_0**2 +\n        gradient_for_id_2_in_core_1**2)\n\n    self.assertAllClose(\n        accumuator_for_id_2,\n        hook.slot_var_values[final_step][2][0],\n        atol=tol, rtol=tol)\n\n    # embedding_update = (\n    #    - LR * g1 / (init_accum + g1**2)\n    #    - LR * g2 / (init_accum + g1**2 + g2**2)\n    gradient_update_for_id_2_after_apply_core_0 = get_embedding_update(\n        gradient_for_id_2_in_core_0)\n    accum_inc = gradient_for_id_2_in_core_0**2\n\n    gradient_update_for_id_2_after_apply_core_1 = get_embedding_update(\n        gradient_for_id_2_in_core_1, previous_accum_inc=accum_inc)\n\n    embedding_update_for_id_2 = (\n        gradient_update_for_id_2_after_apply_core_0 +\n        gradient_update_for_id_2_after_apply_core_1)\n\n    self.assertAllClose(\n        embedding_update_for_id_2,\n        
hook.var_values[final_step][2][0],\n        rtol=tol, atol=tol)\n\n    # For embedding ID 1, it appears in\n    #  - first sample of first core\n    #  - second sample of first core\n    #  - first sample of second core\n    #\n    # So, the gradient for each sample\n    #\n    #  - first sample of first core -> gradient = 0.5 * 3.0 = 1.5\n    #  - second sample of first core -> gradient = 0.5 * 2.0 = 1.0\n    #  - first sample of second core -> gradient = 0.5 * 1.0 = 0.5\n    #\n    # Baracore combines the gradients in single core and then applies them core\n    # by core.\n    gradient_for_id_1_in_core_0 = 1.5 + 1.0\n    gradient_for_id_1_in_core_1 = 0.5\n    accumuator_for_id_1 = (\n        ADADGRAD_INIT_VALUE + gradient_for_id_1_in_core_0**2 +\n        gradient_for_id_1_in_core_1**2)\n\n    self.assertAllClose(\n        accumuator_for_id_1,\n        hook.slot_var_values[final_step][1][0],\n        atol=tol, rtol=tol)\n\n    # ID 1 resides on Core 1, so the updates from Core 1 are applied first.\n    # embedding_update = (\n    #    - LR * g1 / (init_accum + g1**2)\n    #    - LR * g2 / (init_accum + g1**2 + g2**2)\n    gradient_update_for_id_1_after_apply_core_1 = get_embedding_update(\n        gradient_for_id_1_in_core_1)\n    accum_inc = gradient_for_id_1_in_core_1**2\n\n    gradient_update_for_id_1_after_apply_core_0 = get_embedding_update(\n        gradient_for_id_1_in_core_0, previous_accum_inc=accum_inc)\n\n    embedding_update_for_id_1 = (\n        gradient_update_for_id_1_after_apply_core_0 +\n        gradient_update_for_id_1_after_apply_core_1)\n\n    self.assertAllClose(\n        embedding_update_for_id_1,\n        hook.var_values[final_step][1][0],\n        rtol=tol, atol=tol)\n\n  def test_adagrad_opt_embedding_variables_on_cpu(self):\n    use_tpu = False\n    per_core_batch_size = 2\n    num_shards = FLAGS.test_num_shards\n    batch_size = num_shards * per_core_batch_size\n\n    hook = _EmbeddingVariableHook(use_tpu=use_tpu)\n\n    estimator = 
get_estimator(use_tpu, self._model_dir, get_feature_columns(),\n                              batch_size)\n    estimator.train(self._input_fn, steps=1, hooks=[hook])\n\n    final_step = 1\n    tol = DEFAULT_TOL\n\n    # In this CPU example, the gradients for embedding rows are same as the\n    # above: not only the sample loss and gradient, but also the scaling. The\n    # only difference is CPU combines all gradients in one update; while TPU\n    # updates the gradients core by core.\n    #\n    # For ID 0: gradient = 0.5 * 3.0 = 1.5\n    # For ID 1: gradient = 0.5 * 3.0 + 0.5 * 2.0 + 0.5 * 1.0 = 3.0\n    # For ID 2: gradient = 0.5 * 2.0 + 0.5 * 1.0 = 1.5\n    # For ID 3: gradient = 0.5 * 0.0 = 0.0\n\n    gradients_for_embedding = np.array([1.5, 3.0, 1.5, 0])\n\n    # Check accumulator after one step.\n    for index, gradient in enumerate(gradients_for_embedding):\n      accumulator = ADADGRAD_INIT_VALUE + gradient**2\n      self.assertAllClose(\n          accumulator,\n          hook.slot_var_values[final_step][index][0],\n          atol=tol, rtol=tol)\n\n    # Check embedding value after one step.\n    for index, gradient in enumerate(gradients_for_embedding):\n      embedding_update = get_embedding_update(gradient)\n      self.assertAllClose(\n          embedding_update,\n          hook.var_values[final_step][index][0],\n          atol=tol, rtol=tol)\n\n  def test_sgd_opt_embedding_variables_on_cpu(self):\n    use_tpu = False\n    per_core_batch_size = 2\n    num_shards = FLAGS.test_num_shards\n    batch_size = num_shards * per_core_batch_size\n\n    hook = _EmbeddingVariableHook(use_tpu=use_tpu, include_slot_vars=False)\n\n    estimator = get_estimator(\n        use_tpu,\n        self._model_dir,\n        get_feature_columns(),\n        batch_size,\n        optimizer_type='sgd')\n    estimator.train(self._input_fn, steps=1, hooks=[hook])\n\n    final_step = 1\n    tol = DEFAULT_TOL\n\n    # In this CPU example, the gradients for embedding rows are same as 
the\n    # above: not only the sample loss and gradient, but also the scaling. The\n    # only difference is CPU combines all gradients in one update; while TPU\n    # updates the gradients core by core.\n    #\n    # For ID 0: gradient = 0.5 * 3.0 = 1.5\n    # For ID 1: gradient = 0.5 * 3.0 + 0.5 * 2.0 + 0.5 * 1.0 = 3.0\n    # For ID 2: gradient = 0.5 * 2.0 + 0.5 * 1.0 = 1.5\n    # For ID 3: gradient = 0.5 * 0.0 = 0.0\n\n    gradients_for_embedding = np.array([1.5, 3.0, 1.5, 0])\n    # SGD has simple update rule, w += - lr * g\n    embedding_update = [LEARNING_RATE * (-g) for g in gradients_for_embedding]\n\n    # Check embedding value after one step.\n    for index in range(len(gradients_for_embedding)):\n      self.assertAllClose(\n          embedding_update[index],\n          hook.var_values[final_step][index][0],\n          atol=tol, rtol=tol)\n\n  def test_sgd_opt_embedding_variables_cpu_vs_tpu(self):\n    # For sgd, cpu and tpu should agree.\n    per_core_batch_size = 2\n    num_shards = FLAGS.test_num_shards\n    batch_size = num_shards * per_core_batch_size\n\n    use_tpu = False\n    cpu_hook = _EmbeddingVariableHook(use_tpu=use_tpu, include_slot_vars=False)\n    cpu_estimator = get_estimator(\n        use_tpu,\n        self._model_dir + '_cpu',\n        get_feature_columns(),\n        batch_size,\n        optimizer_type='sgd')\n    cpu_estimator.train(self._input_fn, steps=1, hooks=[cpu_hook])\n\n    use_tpu = True\n    tpu_hook = _EmbeddingVariableHook(use_tpu=use_tpu, include_slot_vars=False)\n    estimator = get_estimator(\n        use_tpu,\n        self._model_dir + '_tpu',\n        get_feature_columns(),\n        batch_size,\n        optimizer_type='sgd')\n    estimator.train(self._input_fn, steps=1, hooks=[tpu_hook])\n\n    tol = DEFAULT_TOL\n    self.assertAllClose(\n        cpu_hook.var_values, tpu_hook.var_values, atol=tol, rtol=tol)\n\n    self.assertAllClose(\n        cpu_hook.bias_values,\n        tpu_hook.bias_values,\n        atol=tol, 
rtol=tol)\n\n    # Test gradient multiplier.\n    def grad_multiplier_fn(global_step):\n      # First global step is 0.\n      return tf.cast(global_step + 1, tf.float32) * GRAD_MULTIPLIER\n\n    tpu_hook2 = _EmbeddingVariableHook(use_tpu=use_tpu, include_slot_vars=False)\n    estimator2 = get_estimator(\n        use_tpu,\n        self._model_dir + '_tpu_grad_multiplier',\n        get_feature_columns(),\n        batch_size,\n        optimizer_type='sgd',\n        grad_multiplier_fn=grad_multiplier_fn)\n    estimator2.train(self._input_fn, steps=1, hooks=[tpu_hook2])\n\n    tol = DEFAULT_TOL\n    self.assertAllClose([v * GRAD_MULTIPLIER for v in cpu_hook.var_values],\n                        tpu_hook2.var_values,\n                        atol=tol * GRAD_MULTIPLIER,\n                        rtol=tol * GRAD_MULTIPLIER)\n\n    self.assertAllClose(\n        cpu_hook.bias_values, tpu_hook2.bias_values, atol=tol, rtol=tol)\n\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/tpu_estimator_input_v2_test.py",
    "content": "# Copyright 2021 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for TPUEstimator.\"\"\"\n\nfrom absl import flags\nfrom absl.testing import parameterized\nimport numpy as np\nimport six\nimport tensorflow.compat.v1 as tf\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.export import export_output\nfrom tensorflow_estimator.python.estimator.tpu import tpu_config\nfrom tensorflow_estimator.python.estimator.tpu import tpu_estimator\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v1\n\nFLAGS = flags.FLAGS\n\n_TRAIN = model_fn_lib.ModeKeys.TRAIN\n_EVAL = model_fn_lib.ModeKeys.EVAL\n_PREDICT = model_fn_lib.ModeKeys.PREDICT\n\n\ndef create_run_config(iterations_per_loop, **kwargs):\n  return tpu_config.RunConfig(\n      master='',\n      tpu_config=tpu_config.TPUConfig(\n          iterations_per_loop=iterations_per_loop,\n          **kwargs),\n  )\n\n\ndef dense_computation(features):\n  return tf_keras_v1.__internal__.legacy.layers.dense(\n      features, 1, kernel_initializer=tf.zeros_initializer())\n\n\ndef model_fn_global_step_incrementer(features, labels, mode, params):\n  del params\n  loss = None\n  train_op = None\n  predictions = dense_computation(features)\n  if mode != _PREDICT:\n    loss = 
tf.losses.mean_squared_error(labels, predictions)\n    optimizer = tf.tpu.CrossShardOptimizer(\n        tf.train.GradientDescentOptimizer(learning_rate=0.5))\n    train_op = optimizer.minimize(loss, tf.train.get_global_step())\n  return tpu_estimator.TPUEstimatorSpec(\n      mode,\n      loss=loss,\n      train_op=train_op,\n      predictions={'predictions': predictions},\n      export_outputs={\n          'test': export_output.PredictOutput({\n              'prediction': predictions\n          })\n      })\n\n\ndef dummy_input_fn_with_dataset(batch_size, fea_len=1, repeat=True, x=None):\n  if x is None:\n    x = np.random.normal(size=[batch_size, fea_len]).astype(np.float32)\n  labels = [[2.0]] * batch_size\n\n  dataset1 = tf.data.Dataset.from_tensor_slices(x)\n  dataset2 = tf.data.Dataset.from_tensor_slices(labels)\n  dataset = tf.data.Dataset.zip((dataset1, dataset2))\n  if repeat:\n    dataset = dataset.repeat()\n  dataset = dataset.batch(batch_size, drop_remainder=True)\n\n  def _map(x, y):\n    return x, y\n\n  return dataset.map(_map)\n\n\nclass TpuEstimatorInputV2Test(parameterized.TestCase, tf.test.TestCase):\n\n  @parameterized.parameters((2, 1), (None, 2))\n  def test_batch_size(self, num_cores_per_replica, num_shards):\n    input_fn_call_count = [0]\n    run_config = create_run_config(\n        iterations_per_loop=4,\n        num_cores_per_replica=num_cores_per_replica,\n        num_shards=num_shards,\n        per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2)\n\n    def _input_fn(params):\n      input_fn_call_count[0] += 1\n      expected_batch_size = 128 // num_shards\n      self.assertEqual(expected_batch_size, params['batch_size'])\n      return dummy_input_fn_with_dataset(batch_size=params['batch_size'])\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=model_fn_global_step_incrementer,\n        config=run_config,\n        train_batch_size=128)\n    self.assertEqual(0, input_fn_call_count[0])\n    est.train(_input_fn, 
steps=1)\n    self.assertEqual(1, input_fn_call_count[0])\n\n  def test_run_spatial_partition(self):\n    input_fn_call_count = [0]\n    run_config = create_run_config(\n        iterations_per_loop=4,\n        num_cores_per_replica=2,\n        num_shards=1,\n        input_partition_dims=[[1, 2], None],\n        per_host_input_for_training=(\n            tpu_config.InputPipelineConfig.PER_HOST_V2))\n\n    def _input_fn(params):\n      input_fn_call_count[0] += 1\n      return dummy_input_fn_with_dataset(\n          batch_size=params['batch_size'], fea_len=2)\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=model_fn_global_step_incrementer,\n        config=run_config,\n        train_batch_size=128)\n    self.assertEqual(0, input_fn_call_count[0])\n    est.train(_input_fn, steps=1)\n    self.assertEqual(1, input_fn_call_count[0])\n\n  def test_predict_mode(self):\n    input_fn_call_count = [0]\n    predict_batch_size = 128\n    run_config = create_run_config(\n        iterations_per_loop=4,\n        num_cores_per_replica=2,\n        num_shards=1,\n        input_partition_dims=[[1, 2], None],\n        per_host_input_for_training=(\n            tpu_config.InputPipelineConfig.PER_HOST_V2))\n\n    def _input_fn(params):\n      input_fn_call_count[0] += 1\n      return dummy_input_fn_with_dataset(\n          batch_size=params['batch_size'], fea_len=2)\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=model_fn_global_step_incrementer,\n        config=run_config,\n        train_batch_size=128,\n        predict_batch_size=predict_batch_size)\n\n    self.assertEqual(0, input_fn_call_count[0])\n\n    predictor = est.predict(_input_fn, yield_single_examples=False)\n    prediction = six.next(predictor)\n\n    self.assertEqual(1, input_fn_call_count[0])\n    self.assertIn('predictions', prediction)\n    self.assertEqual((predict_batch_size, 1), prediction['predictions'].shape)\n\n    predictor = est.predict(_input_fn, yield_single_examples=True)\n    prediction 
= six.next(predictor)\n\n    self.assertEqual(2, input_fn_call_count[0])\n    self.assertIn('predictions', prediction)\n    self.assertEqual((1,), prediction['predictions'].shape)\n\n  def test_evaluate_mode(self):\n    input_fn_call_count = [0]\n    eval_batch_size = 128\n    run_config = create_run_config(\n        iterations_per_loop=4,\n        num_cores_per_replica=2,\n        num_shards=1,\n        input_partition_dims=[[1, 2], None],\n        per_host_input_for_training=(\n            tpu_config.InputPipelineConfig.PER_HOST_V2))\n\n    def _input_fn(params):\n      input_fn_call_count[0] += 1\n      return dummy_input_fn_with_dataset(\n          batch_size=params['batch_size'], fea_len=2)\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=model_fn_global_step_incrementer,\n        config=run_config,\n        train_batch_size=128,\n        eval_batch_size=eval_batch_size)\n\n    self.assertEqual(0, input_fn_call_count[0])\n    est.evaluate(_input_fn, steps=1)\n    self.assertEqual(1, input_fn_call_count[0])\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/tpu_estimator_integration_test.py",
    "content": "# Copyright 2021 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for TPUEstimator.\"\"\"\n\nimport contextlib\nimport tempfile\nfrom absl import flags\nimport numpy as np\nimport tensorflow.compat.v1 as tf\n# pylint: disable=g-direct-tensorflow-import\nfrom tensorflow.core.example import example_pb2\nfrom tensorflow.core.example import feature_pb2\n\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.export import export_output\nfrom tensorflow_estimator.python.estimator.tpu import tpu_config\nfrom tensorflow_estimator.python.estimator.tpu import tpu_estimator\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v1\n# pylint: enable=g-direct-tensorflow-import\n\nflags.DEFINE_integer('test_num_shards', 8, 'number of replicas to test')\n\n\nFLAGS = flags.FLAGS\n\n_TRAIN = model_fn_lib.ModeKeys.TRAIN\n_EVAL = model_fn_lib.ModeKeys.EVAL\n_PREDICT = model_fn_lib.ModeKeys.PREDICT\n\n_PER_HOST = 'per_host_sharding'\n_PER_SHARD = 'per_shard_sharding'\n_UNSHARDED = 'unsharded'\n_INPUT_PIPELINE_WITH_QUEUE_RUNNER = (\n    'Input pipeline contains one or more QueueRunners')\n\n\ndef dense_computation(features):\n  return tf_keras_v1.__internal__.legacy.layers.dense(\n      
features['x'], 1, kernel_initializer=tf.zeros_initializer())\n\n\ndef model_fn_global_step_incrementer(features, labels, mode, params):\n  del params\n  loss = None\n  train_op = None\n  predictions = dense_computation(features)\n  if mode != _PREDICT:\n    loss = tf.losses.mean_squared_error(labels, predictions)\n    optimizer = tf.tpu.CrossShardOptimizer(\n        tf.train.GradientDescentOptimizer(learning_rate=0.5))\n    train_op = optimizer.minimize(loss, tf.train.get_global_step())\n  return tpu_estimator.TPUEstimatorSpec(\n      mode,\n      loss=loss,\n      train_op=train_op,\n      predictions={'predictions': predictions},\n      export_outputs={\n          'test': export_output.PredictOutput({\n              'prediction': predictions\n          })\n      })\n\n\ndef dummy_input_fn_with_dataset(batch_size, repeat=True, x=None):\n  if x is None:\n    x = np.random.normal(size=[batch_size, 1]).astype(np.float32)\n  labels = [[2.0]] * batch_size\n\n  dataset1 = tf.data.Dataset.from_tensor_slices(x)\n  dataset2 = tf.data.Dataset.from_tensor_slices(labels)\n  dataset = tf.data.Dataset.zip((dataset1, dataset2))\n  if repeat:\n    dataset = dataset.repeat()\n  dataset = dataset.batch(batch_size, drop_remainder=True)\n\n  def _map(x, y):\n    return {'x': x}, y\n\n  return dataset.map(_map)\n\n\ndef dummy_input_fn(batch_size, repeat=True):\n  dataset = dummy_input_fn_with_dataset(batch_size, repeat)\n  iterator = dataset.make_one_shot_iterator()\n  return iterator.get_next()\n\n\ndef create_run_config(iterations_per_loop, **kwargs):\n  return tpu_config.RunConfig(\n      master='',\n      tpu_config=tpu_config.TPUConfig(\n          iterations_per_loop=iterations_per_loop,\n          num_shards=FLAGS.test_num_shards,\n          **kwargs),\n  )\n\n\nclass TPUEstimatorIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._recorded_input_fn_invoke_metadata = {\n        _TRAIN: {'called_count': 0, 'batch_size': None},\n        _EVAL: {'called_count': 0, 
'batch_size': None},\n        _PREDICT: {'called_count': 0, 'batch_size': None}\n    }\n    self._data = np.linspace(0., 1., 100, dtype=np.float32).reshape(-1, 1)\n    self._export_mode = False\n\n  @contextlib.contextmanager\n  def export_mode(self):\n    \"\"\"Enable the export mode for model_fn.\"\"\"\n    # Inside the model_fn, the test will check the batch size passed via params.\n    # However, export mode should not have that. It is infeasible for model_fn\n    # to distinguish the predict vs export mode today. So, this contextmanager\n    # helps the model_fn to do that.\n    self._export_mode = True\n    yield\n    self._export_mode = False\n\n  def assertInputFnCalledCountAndBatch(self, expected_called_count,\n                                       expected_batch_size):\n    real_called_count = {k: v['called_count'] for k, v in\n                         self._recorded_input_fn_invoke_metadata.items()}\n    real_batch_size = {k: v['batch_size'] for k, v in\n                       self._recorded_input_fn_invoke_metadata.items()}\n    self.assertEqual(expected_called_count, real_called_count)\n    self.assertEqual(expected_batch_size, real_batch_size)\n\n  def _generate_expected_batch_size_and_called_count(\n      self,\n      num_shards,\n      train_batch_size,\n      eval_batch_size,\n      predict_batch_size,\n      train_sharding_policy=_UNSHARDED,\n      eval_sharding_policy=_UNSHARDED,\n      predict_sharding_policy=None):\n\n    expected_batch_size_for_model_fn = {}\n    expected_batch_size_for_input_fn = {}\n    expected_called_count_for_input_fn = {}\n\n    if train_sharding_policy == _PER_SHARD:\n      self.assertEqual(0, train_batch_size % num_shards)\n      expected_batch_size_for_model_fn[_TRAIN] = train_batch_size // num_shards\n      expected_batch_size_for_input_fn[_TRAIN] = train_batch_size // num_shards\n      expected_called_count_for_input_fn[_TRAIN] = num_shards\n    elif train_sharding_policy == _PER_HOST:\n      self.assertEqual(0, 
train_batch_size % num_shards)\n      expected_batch_size_for_model_fn[_TRAIN] = train_batch_size // num_shards\n      expected_batch_size_for_input_fn[_TRAIN] = train_batch_size\n      expected_called_count_for_input_fn[_TRAIN] = 1\n    else:\n      expected_batch_size_for_model_fn[_TRAIN] = train_batch_size\n      expected_batch_size_for_input_fn[_TRAIN] = train_batch_size\n      expected_called_count_for_input_fn[_TRAIN] = 1\n\n    if eval_sharding_policy == _PER_HOST:\n      self.assertEqual(0, train_batch_size % num_shards)\n      expected_batch_size_for_model_fn[_EVAL] = eval_batch_size // num_shards\n      expected_batch_size_for_input_fn[_EVAL] = eval_batch_size\n      expected_called_count_for_input_fn[_EVAL] = 1\n    else:\n      expected_batch_size_for_model_fn[_EVAL] = eval_batch_size\n      expected_batch_size_for_input_fn[_EVAL] = eval_batch_size\n      expected_called_count_for_input_fn[_EVAL] = 1\n\n    if predict_sharding_policy is None:\n      # On CPU.\n      expected_batch_size_for_model_fn[_PREDICT] = predict_batch_size\n      expected_batch_size_for_input_fn[_PREDICT] = predict_batch_size\n      expected_called_count_for_input_fn[_PREDICT] = 1\n    else:\n      expected_batch_size_for_model_fn[_PREDICT] = (\n          predict_batch_size // num_shards)\n      expected_batch_size_for_input_fn[_PREDICT] = predict_batch_size\n      expected_called_count_for_input_fn[_PREDICT] = 1\n\n    return (expected_batch_size_for_model_fn, expected_batch_size_for_input_fn,\n            expected_called_count_for_input_fn)\n\n  def _wrap_input_fn_with_batch_size(self, batch_size, input_fn):\n    def _input_fn(params):\n      self.assertNotIn('batch_size', params)\n      params['batch_size'] = batch_size\n      return input_fn(params)\n    return _input_fn\n\n  def _make_input_fn(self, mode, repeat=False, take=None):\n    metadata = self._recorded_input_fn_invoke_metadata[mode]\n    def _input_fn(params):\n      metadata['called_count'] += 1\n      batch_size = 
params['batch_size']\n\n      if metadata['batch_size'] is None:\n        metadata['batch_size'] = batch_size\n      else:\n        self.assertEqual(batch_size, metadata['batch_size'])\n\n      dataset1 = tf.data.Dataset.from_tensor_slices(self._data)\n      dataset2 = tf.data.Dataset.from_tensor_slices(self._data)\n      dataset = tf.data.Dataset.zip((dataset1, dataset2))\n\n      if repeat:\n        dataset = dataset.repeat()\n\n      dataset = dataset.batch(batch_size)\n\n      if take:\n        dataset = dataset.take(take)\n\n      def _map_fn(x, y):\n        x.set_shape([batch_size, 1])\n        y.set_shape([batch_size, 1])\n        return {'x': x}, y\n\n      dataset = dataset.map(_map_fn)\n      return dataset\n\n    return _input_fn\n\n  def _make_model_fn(self, batch_size_dict, use_tpu_estimator_spec=False):\n\n    def _create_estimator_spec(mode, loss=None, predictions=None,\n                               export_outputs=None, eval_metrics=None,\n                               train_op=None):\n      if use_tpu_estimator_spec:\n        return tpu_estimator.TPUEstimatorSpec(\n            mode=mode,\n            loss=loss,\n            train_op=train_op,\n            predictions=predictions,\n            export_outputs=export_outputs,\n            eval_metrics=eval_metrics)\n      else:\n        return model_fn_lib.EstimatorSpec(\n            mode=mode,\n            loss=loss,\n            train_op=train_op,\n            predictions=predictions,\n            export_outputs=export_outputs,\n            eval_metric_ops=(eval_metrics[0](*eval_metrics[1]) if eval_metrics\n                             else None))\n\n    def _model_fn(features, labels, mode, params):\n      if not self._export_mode:\n        # Always check batch size in params\n        self.assertEqual(batch_size_dict[mode], params['batch_size'])\n      else:\n        self.assertNotIn('batch_size', params)\n\n      # Check the input feeds correct shape for train and eval. 
When eval on CPU\n      # or predict, it is allowed to have dynamic shape. So, here only validates\n      # the fully known shape (which covers the TPU train).\n      if features['x'].shape.is_fully_defined():\n        self.assertEqual(batch_size_dict[mode], features['x'].shape[0])\n\n      predictions = tf_keras_v1.__internal__.legacy.layers.dense(\n          features['x'], 1,\n          kernel_initializer=tf.ones_initializer())\n      export_outputs = {\n          'predictions': export_output.RegressionOutput(predictions)\n      }\n\n      if mode == _PREDICT:\n        return _create_estimator_spec(\n            mode=mode,\n            predictions={'predictions': predictions},\n            export_outputs=export_outputs)\n\n      loss = tf.losses.mean_squared_error(labels, predictions)\n\n      optimizer = tf.tpu.CrossShardOptimizer(\n          tf.train.GradientDescentOptimizer(learning_rate=0.5))\n      train_op = optimizer.minimize(loss,\n                                    global_step=tf.train.get_global_step())\n\n      eval_metrics = (\n          lambda labels, predictions: {  # pylint: disable=g-long-lambda\n              'absolute_error': tf.metrics.mean_absolute_error(\n                  labels, predictions)},\n          [labels, predictions])\n      return _create_estimator_spec(\n          mode=mode,\n          loss=loss,\n          predictions={'predictions': predictions},\n          export_outputs=export_outputs,\n          train_op=train_op,\n          eval_metrics=eval_metrics)\n    return _model_fn\n\n  def _test_identity_savedmodel(self, export_dir):\n    with tf.Graph().as_default() as graph:\n      with tf.Session(graph=graph) as sess:\n        metagraph_def = tf.saved_model.loader.load(sess, [tf.saved_model.SERVING], export_dir)\n        fetch = metagraph_def.signature_def['predictions'].outputs['outputs']\n        feed = metagraph_def.signature_def['predictions'].inputs['inputs']\n        for x in self._data:\n          example = 
example_pb2.Example(\n              features=feature_pb2.Features(\n                  feature={\n                      'x':\n                          feature_pb2.Feature(\n                              float_list=feature_pb2.FloatList(\n                                  value=np.ravel(x)))\n                  })).SerializeToString()\n          y = sess.run(fetch.name, feed_dict={feed.name: [example]})\n          self.assertAlmostEqual(y, x[0], delta=0.01)\n\n  def test_complete_flow_with_per_core_input(self):\n    # Choose the train_batch_size divisible by 2 and 8 (common shards in test\n    # env) and batch_size for eval and predict prime number.\n    train_batch_size = 16\n    eval_batch_size = 16\n    predict_batch_size = 8\n\n    run_config = create_run_config(iterations_per_loop=4,\n                                   per_host_input_for_training=False)\n    num_shards = run_config.tpu_config.num_shards\n\n    (expected_batch_size_for_model_fn, expected_batch_size_for_input_fn,\n     expected_called_count_for_input_fn) = (\n         self._generate_expected_batch_size_and_called_count(\n             num_shards,\n             train_batch_size,\n             eval_batch_size,\n             predict_batch_size,\n             train_sharding_policy=_PER_SHARD,\n             eval_sharding_policy=_PER_HOST,\n             predict_sharding_policy=_PER_HOST))\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._make_model_fn(\n            expected_batch_size_for_model_fn, use_tpu_estimator_spec=True),\n        config=run_config,\n        train_batch_size=train_batch_size,\n        eval_batch_size=eval_batch_size,\n        predict_batch_size=predict_batch_size)\n\n    # TRAIN\n    # learn y = x\n    # Note: Gradients are all zero. 
Just testing execution.\n    def _input_fn(params):\n      dataset = self._make_input_fn(mode=_TRAIN, repeat=True)(params)\n      return tf.data.make_one_shot_iterator(dataset).get_next()\n\n    train_input_fn = _input_fn\n    est.train(train_input_fn, steps=7)\n\n    # EVALUTE\n    scores = est.evaluate(self._make_input_fn(mode=_EVAL), steps=6)\n    self.assertEqual(7, scores['global_step'])\n    self.assertGreater(0.1, scores['absolute_error'])\n\n    # PREDICT\n    predict_input_fn = self._make_input_fn(mode=_PREDICT, take=2)\n    predictions = [x['predictions'] for x in est.predict(predict_input_fn)]\n    self.assertAllClose(\n        self._data[:predict_batch_size * 2], predictions, atol=0.01)\n\n    # Verify all input_fn invoke recorded metadata.\n    self.assertInputFnCalledCountAndBatch(\n        expected_called_count_for_input_fn, expected_batch_size_for_input_fn)\n\n    # EXPORT\n    feature_spec = {'x': tf.io.FixedLenFeature([1], tf.float32)}\n    serving_input_receiver_fn = (\n        export.build_parsing_serving_input_receiver_fn(feature_spec))\n    with self.export_mode():\n      export_dir = est.export_saved_model(\n          tempfile.mkdtemp(dir=self.get_temp_dir()), serving_input_receiver_fn)\n    self.assertTrue(tf.gfile.Exists(export_dir))\n    self._test_identity_savedmodel(export_dir)\n\n  def test_complete_flow_with_per_host_input(self):\n    # Choose the train_batch_size divisible by 2 and 8 (common shards in test\n    # env) and batch_size for eval and predict prime number.\n    train_batch_size = 16\n    eval_batch_size = 16\n    predict_batch_size = 16\n\n    run_config = create_run_config(\n        iterations_per_loop=4, per_host_input_for_training=True)\n    num_shards = run_config.tpu_config.num_shards\n\n    (expected_batch_size_for_model_fn, expected_batch_size_for_input_fn,\n     expected_called_count_for_input_fn) = (\n         self._generate_expected_batch_size_and_called_count(\n             num_shards,\n             
train_batch_size,\n             eval_batch_size,\n             predict_batch_size,\n             train_sharding_policy=_PER_HOST,\n             eval_sharding_policy=_PER_HOST,\n             predict_sharding_policy=_PER_HOST))\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._make_model_fn(\n            expected_batch_size_for_model_fn, use_tpu_estimator_spec=True),\n        config=run_config,\n        train_batch_size=train_batch_size,\n        eval_batch_size=eval_batch_size,\n        predict_batch_size=predict_batch_size)\n\n    # TRAIN\n    # learn y = x\n    # Note: Gradients are all zero. Just testing execution.\n    train_input_fn = self._make_input_fn(mode=_TRAIN, repeat=True)\n    est.train(train_input_fn, steps=7)\n\n    # EVALUTE\n    scores = est.evaluate(self._make_input_fn(mode=_EVAL), steps=6)\n    self.assertEqual(7, scores['global_step'])\n    self.assertGreater(0.1, scores['absolute_error'])\n\n    # PREDICT\n    predict_input_fn = self._make_input_fn(mode=_PREDICT, take=2)\n    predictions = [x['predictions'] for x in est.predict(predict_input_fn)]\n    self.assertAllClose(\n        self._data[:predict_batch_size * 2], predictions, atol=0.01)\n\n    # Verify all input_fn invoke recorded metadata.\n    self.assertInputFnCalledCountAndBatch(\n        expected_called_count_for_input_fn, expected_batch_size_for_input_fn)\n\n    # EXPORT\n    feature_spec = {'x': tf.io.FixedLenFeature([1], tf.float32)}\n    serving_input_receiver_fn = (\n        export.build_parsing_serving_input_receiver_fn(feature_spec))\n    with self.export_mode():\n      export_dir = est.export_saved_model(\n          tempfile.mkdtemp(dir=self.get_temp_dir()), serving_input_receiver_fn)\n    self.assertTrue(tf.gfile.Exists(export_dir))\n    self._test_identity_savedmodel(export_dir)\n\n  def test_complete_flow_with_eval_on_tpu(self):\n    # Choose the train_batch_size divisible by 2 and 8 (common shards in test\n    # env) and batch_size for eval and predict prime 
number.\n    train_batch_size = 16\n    eval_batch_size = 8\n    predict_batch_size = 8\n\n    run_config = create_run_config(iterations_per_loop=4)\n    num_shards = run_config.tpu_config.num_shards\n\n    (expected_batch_size_for_model_fn, expected_batch_size_for_input_fn,\n     expected_called_count_for_input_fn) = (\n         self._generate_expected_batch_size_and_called_count(\n             num_shards,\n             train_batch_size,\n             eval_batch_size,\n             predict_batch_size,\n             train_sharding_policy=_PER_HOST,\n             eval_sharding_policy=_PER_HOST,\n             predict_sharding_policy=_PER_HOST))\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._make_model_fn(\n            expected_batch_size_for_model_fn, use_tpu_estimator_spec=True),\n        config=run_config,\n        train_batch_size=train_batch_size,\n        eval_batch_size=eval_batch_size,\n        predict_batch_size=predict_batch_size)\n\n    # TRAIN\n    # learn y = x\n    # Note: Gradients are all zero. 
Just testing execution.\n    train_input_fn = self._make_input_fn(mode=_TRAIN, repeat=True)\n    est.train(train_input_fn, steps=7)\n\n    # EVALUTE\n    eval_input_fn = self._make_input_fn(mode=_EVAL, repeat=False)\n    scores = est.evaluate(eval_input_fn, steps=2)\n    self.assertEqual(7, scores['global_step'])\n    self.assertGreater(0.1, scores['absolute_error'])\n\n    # PREDICT\n    predict_input_fn = self._make_input_fn(mode=_PREDICT, take=2)\n    predictions = [x['predictions'] for x in est.predict(predict_input_fn)]\n    self.assertAllClose(\n        self._data[:predict_batch_size * 2], predictions, atol=0.01)\n\n    # Verify all input_fn invoke recorded metadata.\n    self.assertInputFnCalledCountAndBatch(\n        expected_called_count_for_input_fn, expected_batch_size_for_input_fn)\n\n    # EXPORT\n    feature_spec = {'x': tf.io.FixedLenFeature([1], tf.float32)}\n    serving_input_receiver_fn = (\n        export.build_parsing_serving_input_receiver_fn(feature_spec))\n    with self.export_mode():\n      export_dir = est.export_saved_model(\n          tempfile.mkdtemp(dir=self.get_temp_dir()), serving_input_receiver_fn)\n    self.assertTrue(tf.gfile.Exists(export_dir))\n    self._test_identity_savedmodel(export_dir)\n\n  def test_complete_flow_with_no_tpu(self):\n    # Choose the train_batch_size divisible by 2 and 8 (common shards in test\n    # env) and batch_size for eval and predict prime number.\n    train_batch_size = 16\n    eval_batch_size = 8\n    predict_batch_size = 1\n\n    run_config = create_run_config(iterations_per_loop=4)\n    num_shards = run_config.tpu_config.num_shards\n\n    (expected_batch_size_for_model_fn, expected_batch_size_for_input_fn,\n     expected_called_count_for_input_fn) = (\n         self._generate_expected_batch_size_and_called_count(\n             num_shards, train_batch_size, eval_batch_size, predict_batch_size,\n             train_sharding_policy=_UNSHARDED,\n             eval_sharding_policy=_UNSHARDED))\n\n    est 
= tpu_estimator.TPUEstimator(\n        model_fn=self._make_model_fn(\n            expected_batch_size_for_model_fn, use_tpu_estimator_spec=True),\n        config=run_config,\n        train_batch_size=train_batch_size,\n        eval_batch_size=eval_batch_size,\n        predict_batch_size=predict_batch_size,\n        use_tpu=False)\n\n    # TRAIN\n    # learn y = x\n    # Note: Gradients are all zero. Just testing execution.\n    train_input_fn = self._make_input_fn(mode=_TRAIN, repeat=True)\n    est.train(train_input_fn, steps=7)\n\n    # EVALUTE\n    eval_input_fn = self._make_input_fn(mode=_EVAL)\n    scores = est.evaluate(eval_input_fn, steps=2)\n    self.assertEqual(7, scores['global_step'])\n    self.assertGreater(0.1, scores['absolute_error'])\n\n    # PREDICT\n    predict_input_fn = self._make_input_fn(mode=_PREDICT)\n    predictions = [x['predictions'] for x in est.predict(predict_input_fn)]\n    self.assertAllClose(self._data, predictions, atol=0.01)\n\n    # Verify all input_fn invoke recorded metadata.\n    self.assertInputFnCalledCountAndBatch(\n        expected_called_count_for_input_fn, expected_batch_size_for_input_fn)\n\n    # EXPORT\n    feature_spec = {'x': tf.io.FixedLenFeature([1], tf.float32)}\n    serving_input_receiver_fn = (\n        export.build_parsing_serving_input_receiver_fn(feature_spec))\n    with self.export_mode():\n      export_dir = est.export_saved_model(\n          tempfile.mkdtemp(dir=self.get_temp_dir()), serving_input_receiver_fn)\n    self.assertTrue(tf.gfile.Exists(export_dir))\n    self._test_identity_savedmodel(export_dir)\n\n\nif __name__ == '__main__':\n  tf.disable_v2_behavior()\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/tpu_estimator_model_parallelism_test.py",
    "content": "# Copyright 2021 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for TPUEstimator with model parallelism.\"\"\"\n\nfrom absl import flags\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.tpu import tpu_feed\nfrom tensorflow.python.tpu.device_assignment import device_assignment\nfrom tensorflow.python.tpu.topology import Topology\nfrom tensorflow.python.training import evaluation\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.export import export_output\nfrom tensorflow_estimator.python.estimator.tpu import tpu_config\nfrom tensorflow_estimator.python.estimator.tpu import tpu_estimator\nfrom tensorflow_estimator.python.estimator.util import tf_keras_v1\n# pylint: enable=g-direct-tensorflow-import\n\nFLAGS = flags.FLAGS\n\n_TRAIN = model_fn_lib.ModeKeys.TRAIN\n_EVAL = model_fn_lib.ModeKeys.EVAL\n_PREDICT = model_fn_lib.ModeKeys.PREDICT\n\n_PER_HOST = 'per_host_sharding'\n_PER_SHARD = 'per_shard_sharding'\n_UNSHARDED = 'unsharded'\n_INPUT_PIPELINE_WITH_QUEUE_RUNNER = (\n    'Input pipeline contains one or more QueueRunners')\n\n\ndef dense_computation(features):\n  return tf_keras_v1.__internal__.legacy.layers.dense(\n      features['x'], 1, kernel_initializer=tf.compat.v1.zeros_initializer())\n\n\ndef 
model_fn_global_step_incrementer(features, labels, mode, params):\n  del params\n  loss = None\n  train_op = None\n  predictions = dense_computation(features)\n  if mode != _PREDICT:\n    loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)\n    optimizer = tf.compat.v1.tpu.CrossShardOptimizer(\n        tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.5))\n    train_op = optimizer.minimize(loss, tf.compat.v1.train.get_global_step())\n  return tpu_estimator.TPUEstimatorSpec(\n      mode,\n      loss=loss,\n      train_op=train_op,\n      predictions={'predictions': predictions},\n      export_outputs={\n          'test': export_output.PredictOutput({\n              'prediction': predictions\n          })\n      })\n\n\ndef dummy_input_fn_with_dataset(batch_size, repeat=True, x=None):\n  if x is None:\n    x = np.random.normal(size=[batch_size, 1]).astype(np.float32)\n  labels = [[2.0]] * batch_size\n\n  dataset1 = tf.compat.v1.data.Dataset.from_tensor_slices(x)\n  dataset2 = tf.compat.v1.data.Dataset.from_tensor_slices(labels)\n  dataset = tf.compat.v1.data.Dataset.zip((dataset1, dataset2))\n  if repeat:\n    dataset = dataset.repeat()\n  dataset = dataset.batch(batch_size, drop_remainder=True)\n\n  def _map(x, y):\n    return {'x': x}, y\n\n  return dataset.map(_map)\n\n\ndef dummy_input_fn(batch_size, repeat=True):\n  dataset = dummy_input_fn_with_dataset(batch_size, repeat)\n  iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)\n  return iterator.get_next()\n\n\ndef create_run_config(iterations_per_loop, num_shards, num_cores_per_replica,\n                      **kwargs):\n  return tpu_config.RunConfig(\n      master='',\n      tpu_config=tpu_config.TPUConfig(\n          iterations_per_loop=iterations_per_loop,\n          num_shards=num_shards,\n          num_cores_per_replica=num_cores_per_replica,\n          **kwargs))\n\n\nclass TPUEstimatorModelParallelismConstructorTest(tf.test.TestCase):\n\n  def 
test_fail_model_parallelism_for_per_core_input(self):\n    run_config = create_run_config(\n        iterations_per_loop=4,\n        num_shards=1,\n        num_cores_per_replica=2,\n        per_host_input_for_training=False)\n    with self.assertRaisesRegex(ValueError, 'Model parallelism only supports'):\n      tpu_estimator.TPUEstimator(\n          model_fn=model_fn_global_step_incrementer,\n          config=run_config,\n          train_batch_size=128)\n\n\nclass TPUEstimatorModelParallelismTrainingTest(tf.test.TestCase):\n\n  def _train_and_return_global_steps(self,\n                                     iterations_per_loop,\n                                     steps=None,\n                                     max_steps=None,\n                                     pre_train_steps=None,\n                                     **kwargs):\n    \"\"\"Trains the model and returns the list of global steps after each loop.\"\"\"\n\n    def input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    def _model_fn(features, labels, mode, params):\n      return model_fn_global_step_incrementer(features, labels, mode, params)\n\n    run_config = create_run_config(\n        iterations_per_loop=iterations_per_loop,\n        num_shards=1,\n        num_cores_per_replica=2,\n        **kwargs)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16)\n\n    class _TrainStepCheckHook(tf.compat.v1.train.SessionRunHook):\n      \"\"\"Check eval step counter after one session.run.\"\"\"\n\n      def __init__(self):\n        \"\"\"Constructs the run hook.\"\"\"\n        self._global_steps = []\n\n      @property\n      def global_steps(self):\n        return self._global_steps\n\n      def after_run(self, run_context, run_values):\n        global_step = run_context.session.run(tf.compat.v1.train.get_global_step())\n        self._global_steps.append(global_step)\n\n    if 
pre_train_steps:\n      est.train(input_fn, steps=pre_train_steps)\n\n    hook = _TrainStepCheckHook()\n    est.train(input_fn, steps=steps, max_steps=max_steps, hooks=[hook])\n    return hook.global_steps\n\n  def test_train_steps_with_model_parallelism(self):\n    # From scratch.\n    global_steps_per_loop = self._train_and_return_global_steps(\n        iterations_per_loop=40, steps=12)\n    self.assertEqual([12], global_steps_per_loop)\n\n    # From existing checkpoint.\n    global_steps_per_loop = self._train_and_return_global_steps(\n        iterations_per_loop=40, steps=12, pre_train_steps=3)\n    self.assertEqual([15], global_steps_per_loop)\n\n\nclass TPUEstimatorModelParallelismEvaluationTest(tf.test.TestCase):\n\n  def _create_input_fn(self):\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    return _input_fn\n\n  def _create_head(self, mode, loss, eval_metrics):\n    \"\"\"Creates a head returning `TPUEstimatorSpec` based on mode.\"\"\"\n    if mode == _EVAL:\n      return tpu_estimator.TPUEstimatorSpec(\n          mode=mode, eval_metrics=eval_metrics, loss=loss)\n    # Train\n    optimizer = tf.compat.v1.tpu.CrossShardOptimizer(\n        tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.5))\n    train_op = optimizer.minimize(loss, global_step=tf.compat.v1.train.get_global_step())\n    return tpu_estimator.TPUEstimatorSpec(\n        mode=mode, train_op=train_op, loss=loss)\n\n  def _metric_fn_on_cpu(self, labels, predictions):\n    return {\n        'mse': tf.compat.v1.metrics.mean_absolute_error(labels, predictions),\n    }\n\n  def _model_fn_with_eval_tensor_list(self, features, labels, mode, params):\n    del params  # unused.\n    predictions = tf_keras_v1.__internal__.legacy.layers.dense(\n        features['x'], 1, kernel_initializer=tf.compat.v1.zeros_initializer())\n    loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)\n\n    return self._create_head(\n        mode,\n        loss,\n   
     eval_metrics=(self._metric_fn_on_cpu, [labels, predictions]))\n\n  def _model_fn_with_eval_dict(self, features, labels, mode, params):\n    del params  # unused.\n    predictions = tf_keras_v1.__internal__.legacy.layers.dense(\n        features['x'], 1, kernel_initializer=tf.compat.v1.zeros_initializer())\n    loss = tf.compat.v1.losses.mean_squared_error(labels, predictions)\n\n    return self._create_head(\n        mode,\n        loss,\n        eval_metrics=(self._metric_fn_on_cpu, {\n            'labels': labels,\n            'predictions': predictions\n        }))\n\n  def _test_eval_steps(self, expected_eval_steps, iterations):\n\n    run_config = create_run_config(\n        iterations_per_loop=iterations, num_shards=1, num_cores_per_replica=2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._model_fn_with_eval_tensor_list,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16)\n\n    est.train(self._create_input_fn(), steps=1)\n\n    class _EvalStepCheckHook(tf.compat.v1.train.SessionRunHook):\n      \"\"\"Check eval step counter after one session.run.\n\n      As the evaluation sets the eval iterations as the eval steps, the\n      after_run should be invoked only once.\n      \"\"\"\n\n      def __init__(self, iterations_per_loop, test_case):\n        \"\"\"Constructs the run hook.\"\"\"\n        self._iterations = iterations_per_loop\n        self._invoked = False\n        self._test_case = test_case\n\n      def before_run(self, run_context):\n        return tf.compat.v1.train.SessionRunArgs({\n            'eval_steps': evaluation._get_or_create_eval_step()\n        })\n\n      def after_run(self, run_context, run_values):\n        eval_steps = run_values.results['eval_steps']\n        self._test_case.assertEqual(expected_eval_steps, eval_steps)\n        self._test_case.assertFalse(self._invoked)\n        self._invoked = True\n\n    est.evaluate(\n        self._create_input_fn(),\n        
steps=expected_eval_steps,\n        hooks=[_EvalStepCheckHook(iterations, self)])\n\n  def test_eval_metrics_with_tensor_list(self):\n    run_config = create_run_config(\n        iterations_per_loop=2, num_shards=1, num_cores_per_replica=2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._model_fn_with_eval_tensor_list,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16)\n\n    est.train(self._create_input_fn(), steps=1)\n    est.evaluate(self._create_input_fn(), steps=1)\n\n  def test_eval_metrics_with_dict(self):\n    run_config = create_run_config(\n        iterations_per_loop=2, num_shards=1, num_cores_per_replica=2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._model_fn_with_eval_dict,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16)\n\n    est.train(self._create_input_fn(), steps=1)\n    est.evaluate(self._create_input_fn(), steps=1)\n\n  def test_fail_with_wrong_num_shards(self):\n    run_config = create_run_config(\n        iterations_per_loop=2, num_shards=2, num_cores_per_replica=2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._model_fn_with_eval_tensor_list,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16)\n\n    with self.assertRaisesRegex(ValueError, 'num_shards is not set correctly'):\n      est.train(self._create_input_fn(), steps=1)\n\n\nclass TPUEstimatorModelParallelismInFeedTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._topology_2x2x2 = Topology(\n        device_coordinates=np.array(\n            [[[0, 0, 0, 0], [0, 0, 0, 1], [0, 1, 0, 0], [0, 1, 0, 1],\n              [1, 0, 0, 0], [1, 0, 0, 1], [1, 1, 0, 0], [1, 1, 0, 1]]],\n            dtype=np.int32),\n        mesh_shape=np.array([2, 2, 1, 2], dtype=np.int32))\n\n  def test_infeed_even_partition(self):\n    \"\"\"Tests even infeed tensors partition.\"\"\"\n    ds = device_assignment(\n        self._topology_2x2x2, 
num_replicas=1, computation_shape=[1, 1, 1, 2])\n    input_partition_dims = [[2, 1]]\n    # pylint: disable=protected-access\n    partitioned_infeed = tpu_feed._PartitionedInfeedQueue(\n        number_of_tuple_elements=1,\n        host_id=0,\n        input_partition_dims=input_partition_dims,\n        device_assignment=ds)\n    x = tf.zeros((14, 5))\n    tensors = partitioned_infeed._check_dims_and_partition_or_replicate_on_host(\n        x, dims=input_partition_dims[0])\n    self.assertEqual(2, len(tensors))\n    self.assertEqual([(7, 5), (7, 5)], [t.shape for t in tensors])\n    # pylint: enable=protected-access\n\n  def test_infeed_uneven_partition(self):\n    \"\"\"Tests uneven infeed tensors partition.\"\"\"\n    ds = device_assignment(\n        self._topology_2x2x2, num_replicas=1, computation_shape=[2, 2, 1, 2])\n    input_partition_dims = [[4, 2]]\n    # pylint: disable=protected-access\n    partitioned_infeed = tpu_feed._PartitionedInfeedQueue(\n        number_of_tuple_elements=1,\n        host_id=0,\n        input_partition_dims=input_partition_dims,\n        device_assignment=ds)\n    x = tf.zeros((14, 5))\n    tensors = partitioned_infeed._check_dims_and_partition_or_replicate_on_host(\n        x, dims=input_partition_dims[0])\n    self.assertEqual(8, len(tensors))\n    self.assertEqual((2, 2), tensors[-1].shape)\n    # pylint: enable=protected-access\n\n  def test_infeed_tailing_zero_partition(self):\n    \"\"\"Tests infeed tensors partition which causes zero-size tensors.\"\"\"\n    ds = device_assignment(\n        self._topology_2x2x2, num_replicas=1, computation_shape=[1, 2, 1, 2])\n    input_partition_dims = [[4, 1]]\n    # pylint: disable=protected-access\n    partitioned_infeed = tpu_feed._PartitionedInfeedQueue(\n        number_of_tuple_elements=1,\n        host_id=0,\n        input_partition_dims=input_partition_dims,\n        device_assignment=ds)\n    x = tf.zeros((5, 5))\n    tensors = 
partitioned_infeed._check_dims_and_partition_or_replicate_on_host(\n        x, dims=input_partition_dims[0])\n    self.assertEqual(4, len(tensors))\n    self.assertEqual((1, 5), tensors[2].shape)\n    self.assertEqual((0, 5), tensors[3].shape)\n    # pylint: enable=protected-access\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/tpu_estimator_signals_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"TPU Estimator Signalling Tests.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow_estimator.python.estimator.tpu import tpu_estimator\n\n\ndef make_input_fn(num_samples):\n  a = np.linspace(0, 100.0, num=num_samples)\n  b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1))\n\n  def input_fn(params):\n    batch_size = params['batch_size']\n    da1 = tf.compat.v1.data.Dataset.from_tensor_slices(a)\n    da2 = tf.compat.v1.data.Dataset.from_tensor_slices(b)\n\n    dataset = tf.compat.v1.data.Dataset.zip((da1, da2))\n    dataset = dataset.map(lambda fa, fb: {'a': fa, 'b': fb})\n    dataset = dataset.batch(batch_size)\n    return dataset\n\n  return input_fn, (a, b)\n\n\ndef make_input_fn_with_labels(num_samples):\n  a = np.linspace(0, 100.0, num=num_samples)\n  b = np.reshape(np.array(a, dtype=np.float32), (len(a), 1))\n\n  def input_fn(params):\n    batch_size = params['batch_size']\n    da1 = tf.compat.v1.data.Dataset.from_tensor_slices(a)\n    da2 = tf.compat.v1.data.Dataset.from_tensor_slices(b)\n\n    dataset = tf.compat.v1.data.Dataset.zip((da1, da2))\n    dataset = dataset.map(lambda fa, fb: ({'a': fa}, fb))\n   
 dataset = dataset.batch(batch_size)\n    return dataset\n\n  return input_fn, (a, b)\n\n\nclass TPUEstimatorStoppingSignalsTest(tf.test.TestCase):\n\n  def test_normal_output_without_signals(self):\n    num_samples = 4\n    batch_size = 2\n\n    params = {'batch_size': batch_size}\n    input_fn, (a, b) = make_input_fn(num_samples=num_samples)\n\n    with tf.Graph().as_default():\n      dataset = input_fn(params)\n      features = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next()\n\n      # With tf.data.Dataset.batch, the batch is None, i.e., dynamic shape.\n      self.assertIsNone(features['a'].shape.as_list()[0])\n\n      with tf.compat.v1.Session() as sess:\n        result = sess.run(features)\n        self.assertAllEqual(a[:batch_size], result['a'])\n        self.assertAllEqual(b[:batch_size], result['b'])\n\n        # This run should work as num_samples / batch_size = 2.\n        result = sess.run(features)\n        self.assertAllEqual(a[batch_size:num_samples], result['a'])\n        self.assertAllEqual(b[batch_size:num_samples], result['b'])\n\n        with self.assertRaises(tf.errors.OutOfRangeError):\n          # Given num_samples and batch_size, this run should fail.\n          sess.run(features)\n\n  def test_output_with_stopping_signals(self):\n    num_samples = 4\n    batch_size = 2\n\n    params = {'batch_size': batch_size}\n    input_fn, (a, b) = make_input_fn(num_samples=num_samples)\n\n    with tf.Graph().as_default():\n      dataset = input_fn(params)\n      inputs = tpu_estimator._InputsWithStoppingSignals(dataset, batch_size)\n      dataset_initializer = inputs.dataset_initializer()\n      features, _ = inputs.features_and_labels()\n      signals = inputs.signals()\n\n      # With tf.data.Dataset.batch, the batch is None, i.e., dynamic shape.\n      self.assertIsNone(features['a'].shape.as_list()[0])\n\n      with tf.compat.v1.Session() as sess:\n        sess.run(dataset_initializer)\n\n        result, evaluated_signals = 
sess.run([features, signals])\n        self.assertAllEqual(a[:batch_size], result['a'])\n        self.assertAllEqual(b[:batch_size], result['b'])\n        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])\n\n        # This run should work as num_samples / batch_size = 2.\n        result, evaluated_signals = sess.run([features, signals])\n        self.assertAllEqual(a[batch_size:num_samples], result['a'])\n        self.assertAllEqual(b[batch_size:num_samples], result['b'])\n        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])\n\n        # This run should work, *but* see STOP ('1') as signals\n        _, evaluated_signals = sess.run([features, signals])\n        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])\n\n        with self.assertRaises(tf.errors.OutOfRangeError):\n          sess.run(features)\n\n\nclass TPUEstimatorStoppingSignalsWithPaddingTest(tf.test.TestCase):\n\n  def test_num_samples_divisible_by_batch_size(self):\n    num_samples = 4\n    batch_size = 2\n\n    params = {'batch_size': batch_size}\n    input_fn, (a, b) = make_input_fn(num_samples=num_samples)\n\n    with tf.Graph().as_default():\n      dataset = input_fn(params)\n      inputs = tpu_estimator._InputsWithStoppingSignals(\n          dataset, batch_size, add_padding=True)\n      dataset_initializer = inputs.dataset_initializer()\n      features, _ = inputs.features_and_labels()\n      signals = inputs.signals()\n\n      # With padding, all shapes are static now.\n      self.assertEqual(batch_size, features['a'].shape.as_list()[0])\n\n      with tf.compat.v1.Session() as sess:\n        sess.run(dataset_initializer)\n\n        result, evaluated_signals = sess.run([features, signals])\n        self.assertAllEqual(a[:batch_size], result['a'])\n        self.assertAllEqual(b[:batch_size], result['b'])\n        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])\n        self.assertAllEqual([0.] 
* batch_size,\n                            evaluated_signals['padding_mask'])\n\n        # This run should work as num_samples / batch_size = 2.\n        result, evaluated_signals = sess.run([features, signals])\n        self.assertAllEqual(a[batch_size:num_samples], result['a'])\n        self.assertAllEqual(b[batch_size:num_samples], result['b'])\n        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])\n        self.assertAllEqual([0.] * batch_size,\n                            evaluated_signals['padding_mask'])\n\n        # This run should work, *but* see STOP ('1') as signals\n        _, evaluated_signals = sess.run([features, signals])\n        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])\n\n        with self.assertRaises(tf.errors.OutOfRangeError):\n          sess.run(features)\n\n  def test_num_samples_not_divisible_by_batch_size(self):\n    num_samples = 5\n    batch_size = 2\n\n    params = {'batch_size': batch_size}\n    input_fn, (a, b) = make_input_fn_with_labels(num_samples=num_samples)\n\n    with tf.Graph().as_default():\n      dataset = input_fn(params)\n      inputs = tpu_estimator._InputsWithStoppingSignals(\n          dataset, batch_size, add_padding=True)\n      dataset_initializer = inputs.dataset_initializer()\n      features, labels = inputs.features_and_labels()\n      signals = inputs.signals()\n\n      # With padding, all shapes are static.\n      self.assertEqual(batch_size, features['a'].shape.as_list()[0])\n\n      with tf.compat.v1.Session() as sess:\n        sess.run(dataset_initializer)\n\n        evaluated_features, evaluated_labels, evaluated_signals = (\n            sess.run([features, labels, signals]))\n        self.assertAllEqual(a[:batch_size], evaluated_features['a'])\n        self.assertAllEqual(b[:batch_size], evaluated_labels)\n        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])\n        self.assertAllEqual([0.] 
* batch_size,\n                            evaluated_signals['padding_mask'])\n\n        # This run should work as num_samples / batch_size >= 2.\n        evaluated_features, evaluated_labels, evaluated_signals = (\n            sess.run([features, labels, signals]))\n        self.assertAllEqual(a[batch_size:2 * batch_size],\n                            evaluated_features['a'])\n        self.assertAllEqual(b[batch_size:2 * batch_size], evaluated_labels)\n        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])\n        self.assertAllEqual([0.] * batch_size,\n                            evaluated_signals['padding_mask'])\n\n        # This is the final partial batch.\n        evaluated_features, evaluated_labels, evaluated_signals = (\n            sess.run([features, labels, signals]))\n        real_batch_size = num_samples % batch_size\n\n        # Assert the real part.\n        self.assertAllEqual(a[2 * batch_size:num_samples],\n                            evaluated_features['a'][:real_batch_size])\n        self.assertAllEqual(b[2 * batch_size:num_samples],\n                            evaluated_labels[:real_batch_size])\n        # Assert the padded part.\n        self.assertAllEqual([0.0] * (batch_size - real_batch_size),\n                            evaluated_features['a'][real_batch_size:])\n        self.assertAllEqual([[0.0]] * (batch_size - real_batch_size),\n                            evaluated_labels[real_batch_size:])\n\n        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])\n\n        padding = ([.0] * real_batch_size + [1.] 
*\n                   (batch_size - real_batch_size))\n        self.assertAllEqual(padding, evaluated_signals['padding_mask'])\n\n        # This run should work, *but* see STOP ('1') as signals\n        _, evaluated_signals = sess.run([features, signals])\n        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])\n\n        with self.assertRaises(tf.errors.OutOfRangeError):\n          sess.run(features)\n\n  def test_slice(self):\n    num_samples = 3\n    batch_size = 2\n\n    params = {'batch_size': batch_size}\n    input_fn, (a, b) = make_input_fn(num_samples=num_samples)\n\n    with tf.Graph().as_default():\n      dataset = input_fn(params)\n      inputs = tpu_estimator._InputsWithStoppingSignals(\n          dataset, batch_size, add_padding=True)\n      dataset_initializer = inputs.dataset_initializer()\n      features, _ = inputs.features_and_labels()\n      signals = inputs.signals()\n\n      sliced_features = (\n          tpu_estimator._PaddingSignals.slice_tensor_or_dict(features, signals))\n\n      with tf.compat.v1.Session() as sess:\n        sess.run(dataset_initializer)\n\n        result, evaluated_signals = sess.run([sliced_features, signals])\n        self.assertAllEqual(a[:batch_size], result['a'])\n        self.assertAllEqual(b[:batch_size], result['b'])\n        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])\n\n        # This is the final partial batch.\n        result, evaluated_signals = sess.run([sliced_features, signals])\n        self.assertEqual(1, len(result['a']))\n        self.assertAllEqual(a[batch_size:num_samples], result['a'])\n        self.assertAllEqual(b[batch_size:num_samples], result['b'])\n        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])\n\n        # This run should work, *but* see STOP ('1') as signals\n        _, evaluated_signals = sess.run([sliced_features, signals])\n        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])\n\n 
       with self.assertRaises(tf.errors.OutOfRangeError):\n          sess.run(sliced_features)\n\n  def test_slice_with_multi_invocations_per_step(self):\n    num_samples = 3\n    batch_size = 2\n\n    params = {'batch_size': batch_size}\n    input_fn, (a, b) = make_input_fn(num_samples=num_samples)\n\n    with tf.Graph().as_default():\n      dataset = input_fn(params)\n      inputs = tpu_estimator._InputsWithStoppingSignals(\n          dataset, batch_size, add_padding=True, num_invocations_per_step=2)\n      dataset_initializer = inputs.dataset_initializer()\n      features, _ = inputs.features_and_labels()\n      signals = inputs.signals()\n\n      sliced_features = (\n          tpu_estimator._PaddingSignals.slice_tensor_or_dict(features, signals))\n\n      with tf.compat.v1.Session() as sess:\n        sess.run(dataset_initializer)\n\n        result, evaluated_signals = sess.run([sliced_features, signals])\n        self.assertAllEqual(a[:batch_size], result['a'])\n        self.assertAllEqual(b[:batch_size], result['b'])\n        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])\n\n        # This is the final partial batch.\n        result, evaluated_signals = sess.run([sliced_features, signals])\n        self.assertEqual(1, len(result['a']))\n        self.assertAllEqual(a[batch_size:num_samples], result['a'])\n        self.assertAllEqual(b[batch_size:num_samples], result['b'])\n        self.assertAllEqual([[0.]] * batch_size, evaluated_signals['stopping'])\n\n        # We should see 3 continuous batches with STOP ('1') as signals and all\n        # of them have mask 1.\n        _, evaluated_signals = sess.run([sliced_features, signals])\n        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])\n        self.assertAllEqual([1.] 
* batch_size,\n                            evaluated_signals['padding_mask'])\n\n        _, evaluated_signals = sess.run([sliced_features, signals])\n        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])\n        self.assertAllEqual([1.] * batch_size,\n                            evaluated_signals['padding_mask'])\n\n        _, evaluated_signals = sess.run([sliced_features, signals])\n        self.assertAllEqual([[1.]] * batch_size, evaluated_signals['stopping'])\n        self.assertAllEqual([1.] * batch_size,\n                            evaluated_signals['padding_mask'])\n        with self.assertRaises(tf.errors.OutOfRangeError):\n          sess.run(sliced_features)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/tpu_estimator_test.py",
    "content": "# Copyright 2020 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for TPUEstimator.\n\nTo improve the performance, the test has been splitted into multiple parts\n\n1. Integration         tpu_estimator_integration_test\n2. Model Parallellsim  tpu_estimator_model_parallelism_test\n3. Evaluation          tpu_estimator_evaluation_test\n4. Export              tpu_estimator_export_test\n5. 
Input Host v2       tpu_estimator_input_v2_test\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport functools\nimport os\nimport re\nimport tempfile\n\nfrom absl import flags\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\n\n# pylint: disable=g-direct-tensorflow-import\nfrom tensorflow.core.protobuf import cluster_pb2\nfrom tensorflow.core.protobuf import config_pb2\nfrom tensorflow.core.util import event_pb2\nfrom tensorflow.python import data as dataset_lib\nfrom tensorflow.python.data.ops import dataset_ops\nfrom tensorflow.python.framework import constant_op\nfrom tensorflow.python.framework import dtypes\nfrom tensorflow.python.framework import errors\nfrom tensorflow.python.framework import ops\nfrom tensorflow.python.layers import layers\nfrom tensorflow.python.lib.io import tf_record\nfrom tensorflow.python.ops import array_ops\nfrom tensorflow.python.ops import init_ops\nfrom tensorflow.python.ops import math_ops\nfrom tensorflow.python.ops import metrics as metrics_lib\nfrom tensorflow.python.ops import parsing_ops\nfrom tensorflow.python.ops import state_ops\nfrom tensorflow.python.ops import string_ops\nfrom tensorflow.python.ops import summary_ops_v2\nfrom tensorflow.python.ops import variable_scope\nfrom tensorflow.python.ops.gen_array_ops import reshape\nfrom tensorflow.python.ops.losses import losses\nfrom tensorflow.python.ops.random_ops import random_uniform\nfrom tensorflow.python.platform import gfile\nfrom tensorflow.python.platform import test\nfrom tensorflow.python.platform import tf_logging as logging\nfrom tensorflow.python.saved_model import signature_constants\nfrom tensorflow.python.summary import summary as summary_lib\nfrom tensorflow.python.tpu import topology as tf_topology\nfrom tensorflow.python.tpu import tpu_system_metadata as tpu_system_metadata_lib\nfrom tensorflow.python.training import 
moving_averages\nfrom tensorflow.python.training import session_run_hook\nfrom tensorflow.python.training import training\nfrom tensorflow_estimator.python.estimator import estimator as estimator_lib\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator.util import tf_keras\nfrom tensorflow_estimator.python.estimator.export import export\nfrom tensorflow_estimator.python.estimator.export import export_output as export_output_lib\nfrom tensorflow_estimator.python.estimator.inputs import numpy_io\nfrom tensorflow_estimator.python.estimator.tpu import tpu_config\nfrom tensorflow_estimator.python.estimator.tpu import tpu_estimator\n# pylint: enable=g-direct-tensorflow-import\n\n\nflags.DEFINE_integer('test_num_shards', 8, 'number of replicas to test')\n\nFLAGS = flags.FLAGS\n\n_TRAIN = model_fn_lib.ModeKeys.TRAIN\n_EVAL = model_fn_lib.ModeKeys.EVAL\n_PREDICT = model_fn_lib.ModeKeys.PREDICT\n\n_PER_HOST = 'per_host_sharding'\n_PER_SHARD = 'per_shard_sharding'\n_UNSHARDED = 'unsharded'\n_INPUT_PIPELINE_WITH_QUEUE_RUNNER = (\n    'Input pipeline contains one or more QueueRunners')\n\n\ndef events_from_file(filepath):\n  \"\"\"Returns all events in a single event file.\n\n  Args:\n    filepath: Path to the event file.\n\n  Returns:\n    A list of all tf.compat.v1.Event protos in the event file.\n  \"\"\"\n  records = list(tf_record.tf_record_iterator(filepath))\n  result = []\n  for r in records:\n    event = event_pb2.Event()\n    event.ParseFromString(r)\n    result.append(event)\n  return result\n\n\ndef dense_computation(features):\n  x = features['x']\n  if len(x.get_shape().as_list()) == 4:\n    x = math_ops.reduce_sum(x, axis=[1, 2])\n  return layers.dense(x, 1, kernel_initializer=init_ops.zeros_initializer())\n\n\ndef get_model_fn(export_tpu_tensor=True,\n                 export_cpu_tensor=False,\n                 tpu_estimator_spec=True):\n\n  def model_fn(features, labels, mode, params):\n    del 
params\n    loss = None\n    train_op = None\n\n    predictions = dense_computation(features)\n    export_outputs = None\n    if mode != _PREDICT:\n      loss = losses.mean_squared_error(labels, predictions)\n      optimizer = tf.compat.v1.tpu.CrossShardOptimizer(\n          training.GradientDescentOptimizer(learning_rate=0.5))\n      train_op = optimizer.minimize(loss, training.get_global_step())\n    else:\n      if export_tpu_tensor:\n        key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY\n        export_outputs = {\n            key: export_output_lib.PredictOutput({\n                'prediction': predictions\n            })\n        }\n      else:\n        export_outputs = {}\n\n      if export_cpu_tensor:\n\n        def host_call(predictions):\n          classes = string_ops.as_string(predictions, name='classes')\n          classification_output = export_output_lib.ClassificationOutput(\n              classes=classes)\n          export_outputs['classification'] = classification_output\n\n        tf.compat.v1.tpu.outside_compilation(host_call, predictions)\n\n    if tpu_estimator_spec:\n      spec_type = tpu_estimator.TPUEstimatorSpec\n    else:\n      spec_type = model_fn_lib.EstimatorSpec\n\n    return spec_type(\n        mode,\n        loss=loss,\n        train_op=train_op,\n        predictions={'predictions': predictions},\n        export_outputs=export_outputs)\n\n  return model_fn\n\n\ndef dummy_input_fn_with_dataset(dataset_size,\n                                repeat=True,\n                                x=None,\n                                batch_size=None):\n  if batch_size is None:\n    batch_size = dataset_size\n  if x is None:\n    x = np.random.normal(size=[dataset_size, 1]).astype(np.float32)\n  labels = [[2.0]] * dataset_size\n\n  dataset1 = dataset_lib.Dataset.from_tensor_slices(x)\n  dataset2 = dataset_lib.Dataset.from_tensor_slices(labels)\n  dataset = dataset_lib.Dataset.zip((dataset1, dataset2))\n  if repeat:\n    dataset 
= dataset.repeat()\n  dataset = dataset.batch(batch_size, drop_remainder=True)\n\n  def _map(x, y):\n    return {'x': x}, y\n\n  return dataset.map(_map)\n\n\ndef dummy_input_fn(batch_size, repeat=True):\n  dataset = dummy_input_fn_with_dataset(batch_size, repeat)\n  iterator = dataset_ops.make_one_shot_iterator(dataset)\n  return iterator.get_next()\n\n\ndef create_run_config(iterations_per_loop, num_shards=None, **kwargs):\n  return tpu_config.RunConfig(\n      master='',\n      tpu_config=tpu_config.TPUConfig(\n          iterations_per_loop=iterations_per_loop,\n          num_shards=num_shards if num_shards else FLAGS.test_num_shards,\n          **kwargs),\n  )\n\n\nclass TPUEstimatorConstructorTest(test.TestCase):\n\n  def test_reserved_key(self):\n    run_config = create_run_config(iterations_per_loop=4)\n    params = {'batch_size': 128}\n    with self.assertRaisesRegex(ValueError, 'are reserved keys'):\n      tpu_estimator.TPUEstimator(\n          model_fn=get_model_fn(), config=run_config, params=params)\n\n  def test_missing_train_batch_size(self):\n    run_config = create_run_config(iterations_per_loop=4)\n    with self.assertRaisesRegex(ValueError,\n                                '`train_batch_size` cannot be `None`'):\n      tpu_estimator.TPUEstimator(\n          model_fn=get_model_fn(), config=run_config, params={})\n\n  def test_invalid_batch_size(self):\n    run_config = create_run_config(iterations_per_loop=4)\n    with self.assertRaisesRegex(TypeError, 'must be int'):\n      tpu_estimator.TPUEstimator(\n          model_fn=get_model_fn(), config=run_config, train_batch_size=1.0)\n\n  def test_batch_size_with_num_shards_for_per_core_input(self):\n    input_fn_call_count = [0]\n    run_config = create_run_config(\n        iterations_per_loop=4, per_host_input_for_training=False)\n    num_shards = run_config.tpu_config.num_shards\n\n    def _input_fn(params):\n      input_fn_call_count[0] += 1\n      self.assertEqual(128 // num_shards, 
params['batch_size'])\n      return dummy_input_fn(params['batch_size'])\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(), config=run_config, train_batch_size=128)\n    self.assertEqual(0, input_fn_call_count[0])\n    est.train(_input_fn, steps=1)\n    self.assertEqual(num_shards, input_fn_call_count[0])\n\n  def test_batch_size_with_num_shards_for_per_host_input(self):\n    input_fn_call_count = [0]\n    run_config = create_run_config(\n        iterations_per_loop=4, per_host_input_for_training=True)\n\n    def _input_fn(params):\n      input_fn_call_count[0] += 1\n      self.assertEqual(128, params['batch_size'])\n      return dummy_input_fn(params['batch_size'])\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(), config=run_config, train_batch_size=128)\n    self.assertEqual(0, input_fn_call_count[0])\n    est.train(_input_fn, steps=1)\n    self.assertEqual(1, input_fn_call_count[0])\n\n  def test_train_batch_size_with_non_divisible_num_shards(self):\n    run_config = create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(), config=run_config, train_batch_size=127)\n    with self.assertRaisesRegex(ValueError, 'train.*must be divisible'):\n      est.train(dummy_input_fn_with_dataset, steps=1)\n\n  def test_train_batch_size_with_non_divisible_num_shards_broadcast_mode(self):\n    input_fn_call_count = [0]\n    run_config = create_run_config(\n        iterations_per_loop=4,\n        per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST)\n\n    def _input_fn(params):\n      input_fn_call_count[0] += 1\n      self.assertEqual(127, params['batch_size'])\n      return dummy_input_fn(params['batch_size'])\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(), config=run_config, train_batch_size=127)\n    self.assertEqual(0, input_fn_call_count[0])\n    est.train(_input_fn, steps=1)\n    self.assertEqual(1, input_fn_call_count[0])\n\n 
 def test_eval_batch_size_with_non_divisible_num_shards(self):\n    run_config = create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(),\n        config=run_config,\n        train_batch_size=64,\n        eval_batch_size=127)\n    with self.assertRaisesRegex(ValueError, 'eval.*must be divisible'):\n      est.evaluate(dummy_input_fn_with_dataset, steps=1)\n\n  def test_predict_batch_size_with_non_divisible_num_shards_broadcast_mode(\n      self):\n    run_config = create_run_config(\n        iterations_per_loop=4,\n        per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST)\n\n    def _input_fn(params):\n      return dummy_input_fn_with_dataset(params['batch_size'])\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(),\n        config=run_config,\n        train_batch_size=64,\n        predict_batch_size=127)\n    est.train(_input_fn, steps=1)\n    est.predict(_input_fn)\n\n  def test_predict_batch_size_with_non_divisible_num_shards(self):\n    run_config = create_run_config(iterations_per_loop=4)\n\n    def _input_fn(params):\n      return dummy_input_fn_with_dataset(params['batch_size'])\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(),\n        config=run_config,\n        train_batch_size=64,\n        predict_batch_size=127)\n    est.train(_input_fn, steps=1)\n    with self.assertRaisesRegex(ValueError, 'predict.*must be divisible'):\n      list(est.predict(_input_fn))\n\n  def test_invalid_num_shards(self):\n    run_config = tpu_config.RunConfig(\n        master='',\n        tpu_config=tpu_config.TPUConfig(iterations_per_loop=2, num_shards=16))\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(), config=run_config, train_batch_size=128)\n    with self.assertRaisesRegex(ValueError, 'num_shards is not set correctly'):\n      est.train(dummy_input_fn_with_dataset, steps=1)\n\n\nclass 
TPUEstimatorTPUContextTest(test.TestCase):\n\n  def test_context_replicas(self):\n\n    def _input_fn(params):\n      batch_size = params['batch_size']\n      context = params['context']\n      self.assertEqual(FLAGS.test_num_shards, context.num_replicas)\n      self.assertEqual(1, context.num_hosts)\n      self.assertEqual(0, context.current_host)\n      self.assertEqual(FLAGS.test_num_shards, context.num_of_replicas_per_host)\n      return dummy_input_fn(batch_size)\n\n    run_config = create_run_config(\n        iterations_per_loop=4, per_host_input_for_training=False)\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(), config=run_config, train_batch_size=16)\n    est.train(_input_fn, steps=4)\n\n  def _query_system(self, master_address, cluster_def, query_topology):\n    del master_address, cluster_def, query_topology\n    # construct an ideal, not real, topology for 4x4.\n    topology = tf_topology.Topology(\n        mesh_shape=[4, 4, 1, 2],\n        device_coordinates=[\n            [\n                [0, 0, 0, 0],\n                [0, 0, 0, 1],\n                [1, 0, 0, 0],\n                [1, 0, 0, 1],\n                [2, 0, 0, 0],\n                [2, 0, 0, 1],\n                [3, 0, 0, 0],\n                [3, 0, 0, 1],\n            ],\n            [\n                [0, 1, 0, 0],\n                [0, 1, 0, 1],\n                [1, 1, 0, 0],\n                [1, 1, 0, 1],\n                [2, 1, 0, 0],\n                [2, 1, 0, 1],\n                [3, 1, 0, 0],\n                [3, 1, 0, 1],\n            ],\n            [\n                [0, 2, 0, 0],\n                [0, 2, 0, 1],\n                [1, 2, 0, 0],\n                [1, 2, 0, 1],\n                [2, 2, 0, 0],\n                [2, 2, 0, 1],\n                [3, 2, 0, 0],\n                [3, 2, 0, 1],\n            ],\n            [\n                [0, 3, 0, 0],\n                [0, 3, 0, 1],\n                [1, 3, 0, 0],\n                [1, 3, 0, 1],\n         
       [2, 3, 0, 0],\n                [2, 3, 0, 1],\n                [3, 3, 0, 0],\n                [3, 3, 0, 1],\n            ],\n        ],\n    )\n    return tpu_system_metadata_lib.TPUSystemMetadata(\n        num_cores=32,\n        num_hosts=4,\n        num_of_cores_per_host=8,\n        topology=topology,\n        devices=[])\n\n  def test_num_cores_per_replica_is_not_greater_than_num_cores_per_host(self):\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    with test.mock.patch.object(\n        tpu_system_metadata_lib,\n        '_query_tpu_system_metadata',\n        side_effect=self._query_system):\n\n      FLAGS.test_num_shards = 2\n      run_config = create_run_config(\n          iterations_per_loop=1, num_cores_per_replica=16)\n\n      with self.assertRaisesRegex(\n          ValueError,\n          'Except the PER_HOST_V2 mode, the num of cores required by '\n          'model parallelism specified by TPUConfig.num_cores_per_replica '\n          'should be less than or equal to the num_cores_per_host. 
'\n          'num_cores_per_replica: 16, num_cores_per_host: 8'):\n        est = tpu_estimator.TPUEstimator(\n            model_fn=get_model_fn(), config=run_config, train_batch_size=64)\n        est.train(_input_fn, steps=1)\n\n  def test_device_for_replica_fn(self):\n\n    def _input_fn(params):\n      batch_size = params['batch_size']\n      context = params['context']\n\n      with self.assertRaisesRegex(\n          RuntimeError, 'This TPUContext instance must not be '\n          'called from input_fn.'):\n        context.device_assignment()\n\n      for replica_id in range(context.num_replicas):\n        (host_device, ordinal_id) = context.device_for_replica(replica_id)\n        self.assertEqual('/task:0/device:CPU:0', host_device)\n        self.assertEqual(ordinal_id, replica_id)\n\n      return dummy_input_fn(batch_size)\n\n    run_config = create_run_config(\n        iterations_per_loop=4, per_host_input_for_training=True)\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(), config=run_config, train_batch_size=16)\n    est.train(_input_fn, steps=4)\n\n  def test_input_deployment_for_per_host(self):\n    fake_num_cores = 32\n    fake_num_hosts = 4\n    fake_num_cores_per_host = fake_num_cores // fake_num_hosts\n    invocation_count = [0]\n    global_batch_size = 16 * fake_num_cores\n\n    def _input_fn(params):\n      batch_size = params['batch_size']\n      self.assertEqual(global_batch_size // fake_num_hosts, batch_size)\n\n      context = params['context']\n      current_invocation_count = invocation_count[0]\n\n      (current_input_device, invocation_index_in_context, total_invocations,\n       replicas_consumed_by_current_invocation) = (\n           context.current_input_fn_deployment())\n      self.assertEqual('/replica:0/task:0/device:CPU:0', current_input_device)\n      self.assertEqual(current_invocation_count, invocation_index_in_context)\n      self.assertEqual(current_invocation_count, context.current_host)\n      
self.assertEqual(fake_num_hosts, total_invocations)\n      self.assertEqual(fake_num_cores_per_host,\n                       replicas_consumed_by_current_invocation)\n\n      # Use the invocation_count to track the number of invocations.\n      invocation_count[0] = current_invocation_count + 1\n\n      return dummy_input_fn(batch_size)\n\n    with test.mock.patch.object(\n        tpu_system_metadata_lib,\n        '_query_tpu_system_metadata',\n        side_effect=self._query_system):\n\n      run_config = create_run_config(\n          iterations_per_loop=4,\n          num_shards=fake_num_cores,\n          per_host_input_for_training=True)\n\n      est = tpu_estimator.TPUEstimator(\n          model_fn=get_model_fn(),\n          config=run_config,\n          train_batch_size=global_batch_size)\n\n      # This exception is ok as we do not have sufficient TPU cores to run the\n      # model. as far as the assert after it is correct, input pipeline checking\n      # is done and successful.\n      with self.assertRaisesRegex(errors.InvalidArgumentError,\n                                  'there are only 2 cores in the TPU topology'):\n        est.train(_input_fn, steps=4)\n\n      self.assertEqual(fake_num_hosts, invocation_count[0])\n\n  def test_input_deployment_for_per_host_v2(self):\n    fake_num_cores = 32\n    fake_num_hosts = 4\n    fake_num_cores_per_host = fake_num_cores // fake_num_hosts\n    invocation_count = [0]\n    global_batch_size = 16 * fake_num_cores\n\n    def _input_fn(params):\n      batch_size = params['batch_size']\n      self.assertEqual(global_batch_size // fake_num_cores, batch_size)\n\n      context = params['context']\n      current_invocation_count = invocation_count[0]\n\n      (current_input_device, invocation_index_in_context, total_invocations,\n       replicas_consumed_by_current_invocation) = (\n           context.current_input_fn_deployment())\n\n      self.assertEqual('/replica:0/task:0/device:CPU:0', current_input_device)\n      
self.assertEqual(current_invocation_count, invocation_index_in_context)\n      self.assertEqual(fake_num_hosts, total_invocations)\n      self.assertEqual(current_invocation_count, context.current_host)\n      self.assertEqual(fake_num_cores_per_host,\n                       replicas_consumed_by_current_invocation)\n\n      # Use the invocation_count to track the number of invocations.\n      invocation_count[0] = current_invocation_count + 1\n\n      return dummy_input_fn_with_dataset(batch_size)\n\n    with test.mock.patch.object(\n        tpu_system_metadata_lib,\n        '_query_tpu_system_metadata',\n        side_effect=self._query_system):\n\n      run_config = create_run_config(\n          iterations_per_loop=4,\n          num_shards=fake_num_cores,\n          per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2\n      )\n\n      est = tpu_estimator.TPUEstimator(\n          model_fn=get_model_fn(),\n          config=run_config,\n          train_batch_size=global_batch_size)\n      # This exception is ok as we do not have sufficient TPU cores to run the\n      # model. 
as far as the assert after it is correct, input pipeline checking\n      # is done and successful.\n      with self.assertRaisesRegex(errors.InvalidArgumentError,\n                                  'there are only 2 cores in the TPU topology'):\n        est.train(_input_fn, steps=4)\n\n    self.assertEqual(fake_num_hosts, invocation_count[0])\n\n  def test_input_deployment_for_per_host_v2_with_model_parallelism(self):\n    fake_num_cores = 32\n    fake_num_hosts = 4\n    fake_num_cores_per_host = fake_num_cores // fake_num_hosts\n    num_cores_per_replica = 2\n    fake_num_replicas = fake_num_cores // num_cores_per_replica\n    fake_num_replicas_per_host = (\n        fake_num_cores_per_host // num_cores_per_replica)\n\n    invocation_count = [0]\n    global_batch_size = 16 * fake_num_cores\n\n    def _input_fn(params):\n      batch_size = params['batch_size']\n      self.assertEqual(global_batch_size // fake_num_replicas, batch_size)\n\n      context = params['context']\n      current_invocation_count = invocation_count[0]\n\n      (current_input_device, invocation_index_in_context, total_invocations,\n       replicas_consumed_by_current_invocation) = (\n           context.current_input_fn_deployment())\n\n      self.assertEqual('/replica:0/task:0/device:CPU:0', current_input_device)\n      self.assertEqual(current_invocation_count, invocation_index_in_context)\n      self.assertEqual(current_invocation_count, context.current_host)\n      self.assertEqual(fake_num_hosts, total_invocations)\n      self.assertEqual(fake_num_replicas_per_host,\n                       replicas_consumed_by_current_invocation)\n\n      # Use the invocation_count to track the number of invocations.\n      invocation_count[0] = current_invocation_count + 1\n\n      return dummy_input_fn_with_dataset(batch_size)\n\n    with test.mock.patch.object(\n        tpu_system_metadata_lib,\n        '_query_tpu_system_metadata',\n        side_effect=self._query_system):\n\n      run_config = 
create_run_config(\n          iterations_per_loop=4,\n          num_shards=fake_num_replicas,\n          per_host_input_for_training=tpu_config.InputPipelineConfig\n          .PER_HOST_V2,\n          num_cores_per_replica=num_cores_per_replica)\n      est = tpu_estimator.TPUEstimator(\n          model_fn=get_model_fn(),\n          config=run_config,\n          train_batch_size=global_batch_size)\n      # This exception is ok as we do not have sufficient TPU cores to run the\n      # model. as far as the assert after it is correct, input pipeline checking\n      # is done and successful.\n      with self.assertRaisesRegex(errors.InvalidArgumentError,\n                                  'there are only 2 cores in the TPU topology'):\n        est.train(_input_fn, steps=4)\n\n    self.assertEqual(fake_num_hosts, invocation_count[0])\n\n  def test_input_deployment_model_parallelism_cross_host_replica(self):\n    fake_num_cores = 32\n    fake_num_hosts = 4\n    fake_num_cores_per_host = fake_num_cores // fake_num_hosts\n    num_cores_per_replica = 16\n    self.assertGreater(num_cores_per_replica, fake_num_cores_per_host)\n\n    fake_num_replicas = fake_num_cores // num_cores_per_replica\n\n    host_ids = []\n    invocation_count = [0]\n    global_batch_size = 16 * fake_num_cores\n\n    def _input_fn(params):\n      batch_size = params['batch_size']\n      self.assertEqual(global_batch_size // fake_num_replicas, batch_size)\n\n      context = params['context']\n      current_invocation_count = invocation_count[0]\n\n      (current_input_device, invocation_index_in_context, total_invocations,\n       replicas_consumed_by_current_invocation) = (\n           context.current_input_fn_deployment())\n\n      self.assertEqual('/replica:0/task:0/device:CPU:0', current_input_device)\n      self.assertEqual(current_invocation_count, invocation_index_in_context)\n      host_ids.append(context.current_host)\n      self.assertEqual(fake_num_replicas, total_invocations)\n      
self.assertEqual(1, replicas_consumed_by_current_invocation)\n\n      # Use the invocation_count to track the number of invocations.\n      invocation_count[0] = current_invocation_count + 1\n\n      return dummy_input_fn_with_dataset(batch_size)\n\n    with test.mock.patch.object(\n        tpu_system_metadata_lib,\n        '_query_tpu_system_metadata',\n        side_effect=self._query_system):\n\n      run_config = create_run_config(\n          iterations_per_loop=4,\n          num_shards=fake_num_replicas,\n          per_host_input_for_training=tpu_config.InputPipelineConfig\n          .PER_HOST_V2,\n          num_cores_per_replica=num_cores_per_replica)\n      est = tpu_estimator.TPUEstimator(\n          model_fn=get_model_fn(),\n          config=run_config,\n          train_batch_size=global_batch_size)\n      # This exception is ok as we do not have sufficient TPU cores to run the\n      # model. as far as the assert after it is correct, input pipeline checking\n      # is done and successful.\n      with self.assertRaisesRegex(errors.InvalidArgumentError,\n                                  'there are only 2 cores in the TPU topology'):\n        est.train(_input_fn, steps=4)\n\n    self.assertEqual(fake_num_replicas, invocation_count[0])\n    self.assertEqual([0, 2], host_ids)\n\n  def test_input_deployment_for_broadcast_mode(self):\n    invocation_count = [0]\n    global_batch_size = 16\n\n    def _input_fn(params):\n      batch_size = params['batch_size']\n      self.assertEqual(global_batch_size, batch_size)\n\n      context = params['context']\n      current_invocation_count = invocation_count[0]\n\n      (current_input_device, invocation_index_in_context, total_invocations,\n       replicas_consumed_by_current_invocation) = (\n           context.current_input_fn_deployment())\n\n      self.assertEqual('/replica:0/task:0/device:CPU:0', current_input_device)\n      self.assertEqual(current_invocation_count, invocation_index_in_context)\n      
self.assertEqual(1, total_invocations)\n      self.assertEqual(FLAGS.test_num_shards,\n                       replicas_consumed_by_current_invocation)\n\n      # Use the invocation_count to track the number of invocations.\n      invocation_count[0] = current_invocation_count + 1\n\n      return dummy_input_fn_with_dataset(batch_size)\n\n    run_config = create_run_config(\n        iterations_per_loop=4,\n        per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST)\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(),\n        config=run_config,\n        train_batch_size=global_batch_size)\n    est.train(_input_fn, steps=4)\n    self.assertEqual(1, invocation_count[0])\n\n  def test_input_deployment_for_eval_broadcast_mode(self):\n    invocation_count = [0]\n    global_batch_size = 16\n    num_cores = FLAGS.test_num_shards\n\n    def _input_fn(params, is_training=True):\n      batch_size = params['batch_size']\n      self.assertEqual(global_batch_size, batch_size)\n\n      context = params['context']\n      current_invocation_count = invocation_count[0]\n\n      (current_input_device, invocation_index_in_context, total_invocations,\n       replicas_consumed_by_current_invocation) = (\n           context.current_input_fn_deployment())\n\n      self.assertEqual('/replica:0/task:0/device:CPU:0', current_input_device)\n      if is_training:\n        self.assertEqual(current_invocation_count, invocation_index_in_context)\n      else:\n        self.assertEqual(current_invocation_count - 1,\n                         invocation_index_in_context)\n      self.assertEqual(1, total_invocations)\n      self.assertEqual(num_cores, replicas_consumed_by_current_invocation)\n\n      # Use the invocation_count to track the number of invocations.\n      invocation_count[0] = current_invocation_count + 1\n\n      return dummy_input_fn_with_dataset(batch_size)\n\n    run_config = create_run_config(\n        iterations_per_loop=4,\n        
per_host_input_for_training=True,\n        eval_training_input_configuration=tpu_config.InputPipelineConfig.SLICED)\n\n    def _assert_model_fn(features, labels, mode, params):\n      actual_model_fn = get_model_fn()\n      per_replica_batch_size = params['batch_size']\n      self.assertEqual(per_replica_batch_size,\n                       global_batch_size // num_cores)\n      return actual_model_fn(features, labels, mode, params)\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_assert_model_fn,\n        config=run_config,\n        train_batch_size=global_batch_size,\n        eval_batch_size=global_batch_size)\n    est.train(functools.partial(_input_fn, is_training=True), steps=1)\n    self.assertEqual(1, invocation_count[0])\n    est.evaluate(functools.partial(_input_fn, is_training=False), steps=1)\n    self.assertEqual(2, invocation_count[0])\n\n  def test_input_deployment_for_per_core(self):\n    fake_num_cores = 32\n    fake_num_hosts = 4\n    fake_num_cores_per_host = fake_num_cores // fake_num_hosts\n    invocation_count = [0]\n    global_batch_size = 16 * fake_num_cores\n\n    def _input_fn(params):\n      batch_size = params['batch_size']\n      self.assertEqual(global_batch_size // fake_num_cores, batch_size)\n\n      context = params['context']\n      current_invocation_count = invocation_count[0]\n\n      (current_input_device, invocation_index_in_context, total_invocations,\n       replicas_consumed_by_current_invocation) = (\n           context.current_input_fn_deployment())\n\n      self.assertEqual('/replica:0/task:0/device:CPU:0', current_input_device)\n      self.assertEqual(current_invocation_count, invocation_index_in_context)\n      self.assertEqual(current_invocation_count // fake_num_cores_per_host,\n                       context.current_host)\n      self.assertEqual(fake_num_cores, total_invocations)\n      self.assertEqual(1, replicas_consumed_by_current_invocation)\n\n      # Use the invocation_count to track the number of 
invocations.\n      invocation_count[0] = current_invocation_count + 1\n\n      return dummy_input_fn(batch_size)\n\n    with test.mock.patch.object(\n        tpu_system_metadata_lib,\n        '_query_tpu_system_metadata',\n        side_effect=self._query_system):\n\n      run_config = create_run_config(\n          iterations_per_loop=4,\n          num_shards=fake_num_cores,\n          per_host_input_for_training=False)\n\n      est = tpu_estimator.TPUEstimator(\n          model_fn=get_model_fn(),\n          config=run_config,\n          train_batch_size=global_batch_size)\n      # This exception is ok as we do not have sufficient TPU cores to run the\n      # model. as far as the assert after it is correct, input pipeline checking\n      # is done and successful.\n      with self.assertRaisesRegex(errors.InvalidArgumentError,\n                                  'there are only 2 cores in the TPU topology'):\n        est.train(_input_fn, steps=4)\n\n    self.assertEqual(fake_num_cores, invocation_count[0])\n\n  def test_hparams_as_params(self):\n\n    def _input_fn(params):\n      batch_size = params['batch_size']\n      context = params['context']\n      self.assertEqual(FLAGS.test_num_shards, context.num_replicas)\n      return dummy_input_fn(batch_size)\n\n    run_config = create_run_config(\n        iterations_per_loop=4, per_host_input_for_training=False)\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(),\n        params={},\n        config=run_config,\n        train_batch_size=16)\n    est.train(_input_fn, steps=4)\n\n\nclass TPUEstimatorInputFnTest(parameterized.TestCase):\n\n  def setUp(self):\n    # TODO(b/65703635): Remove setting/restoring the constant here.\n    # As we are transitioning from deprecated mode to new mode. 
We have to\n    # test both cases to ensure we do not break clients.\n    super(TPUEstimatorInputFnTest, self).setUp()\n    self._old_value = tpu_estimator._WRAP_INPUT_FN_INTO_WHILE_LOOP\n\n  def tearDown(self):\n    super(TPUEstimatorInputFnTest, self).tearDown()\n    tpu_estimator._WRAP_INPUT_FN_INTO_WHILE_LOOP = self._old_value\n\n  # Use 10 to test TPUEstimator is correctly concatenating small tensors.\n  @parameterized.parameters(1, 10)\n  def test_succeed_with_dataset(self, num_features):\n    tpu_estimator._WRAP_INPUT_FN_INTO_WHILE_LOOP = True\n\n    def _input_fn(params):\n      batch_size = params['batch_size']\n      x = np.random.normal(size=[batch_size, 1]).astype(np.float32)\n      x1 = np.random.normal(size=[batch_size, 1]).astype(np.int32)\n      labels = [[2.0]] * batch_size\n\n      dataset1 = dataset_lib.Dataset.from_tensor_slices(x)\n      dataset2 = dataset_lib.Dataset.from_tensor_slices(x1)\n      dataset3 = dataset_lib.Dataset.from_tensor_slices(labels)\n      dataset = dataset_lib.Dataset.zip((dataset1, dataset2, dataset3))\n\n      def _map_fn(x, x1, y):\n        xs = {}\n        for i in range(num_features):\n          xs['x' * (i + 1)] = array_ops.identity(x)\n          xs['x1' * (i + 1)] = array_ops.identity(x1)\n        return xs, y\n\n      dataset = dataset.map(_map_fn)\n      dataset = dataset.repeat()\n      dataset = dataset.batch(batch_size, drop_remainder=True)\n      return dataset\n\n    run_config = create_run_config(\n        iterations_per_loop=4, per_host_input_for_training=True)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(), config=run_config, train_batch_size=16)\n    est.train(_input_fn, steps=4)\n\n  def test_succeed_with_input_return_features_and_labels_with_dataset(self):\n    tpu_estimator._WRAP_INPUT_FN_INTO_WHILE_LOOP = True\n\n    def _input_fn(params):\n      batch_size = params['batch_size']\n      return dummy_input_fn(batch_size)\n\n    run_config = create_run_config(\n        
iterations_per_loop=4, per_host_input_for_training=False)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(), config=run_config, train_batch_size=16)\n    est.train(_input_fn, steps=4)\n\n  def test_fail_with_queue_based_input_fn_in_while_loop(self):\n    tpu_estimator._WRAP_INPUT_FN_INTO_WHILE_LOOP = True\n\n    data = np.arange(40, dtype=np.float32).reshape(40, 1)\n    x = {'x': data}\n    y = data * 2.0\n\n    def input_fn(params):\n      batch_size = params['batch_size']\n      return numpy_io.numpy_input_fn(\n          x, y, batch_size=batch_size, shuffle=False, num_epochs=None)()\n\n    run_config = create_run_config(\n        iterations_per_loop=4, per_host_input_for_training=False)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(), config=run_config, train_batch_size=16)\n\n    with self.assertRaisesRegex(RuntimeError,\n                                _INPUT_PIPELINE_WITH_QUEUE_RUNNER):\n      est.train(input_fn, steps=4)\n\n  def test_warning_with_queue_based_input_fn(self):\n    tpu_estimator._WRAP_INPUT_FN_INTO_WHILE_LOOP = False\n\n    data = np.arange(40, dtype=np.float32).reshape(40, 1)\n    x = {'x': data}\n    y = data * 2.0\n\n    def input_fn(params):\n      batch_size = params['batch_size']\n      return numpy_io.numpy_input_fn(\n          x, y, batch_size=batch_size, shuffle=False, num_epochs=None)()\n\n    run_config = create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(), config=run_config, train_batch_size=16)\n\n    with test.mock.patch.object(logging, 'warn') as mock_log:\n      est.train(input_fn, steps=4)\n      self.assertRegex(\n          str(mock_log.call_args), _INPUT_PIPELINE_WITH_QUEUE_RUNNER)\n\n  def test_nested_inputs_dict(self):\n    self.help_test_nested_inputs(nest_type='dict')\n\n  def test_nested_inputs_tuple(self):\n    self.help_test_nested_inputs(nest_type='tuple')\n\n  def test_nested_inputs_namedtuple(self):\n    
self.help_test_nested_inputs(nest_type='namedtuple')\n\n  def help_test_nested_inputs(self, nest_type):\n    self.assertIn(nest_type, ['dict', 'tuple', 'namedtuple'])\n    tpu_estimator._WRAP_INPUT_FN_INTO_WHILE_LOOP = True\n\n    class MyTuple(collections.namedtuple('MyTuple', ['a', 'b'])):\n      pass\n\n    def model_fn(features, labels, mode, params):\n      del params\n      if nest_type == 'dict':\n        inputs = features['x']\n      elif nest_type == 'tuple':\n        inputs = features\n      elif nest_type == 'namedtuple':\n        inputs = tuple(features)\n      else:\n        inputs = features\n      predictions = layers.dense(\n          inputs[0], 1, kernel_initializer=init_ops.zeros_initializer())\n      loss = losses.mean_squared_error(labels, predictions)\n      export_outputs = None\n      optimizer = tf.compat.v1.tpu.CrossShardOptimizer(\n          training.GradientDescentOptimizer(learning_rate=0.5))\n      train_op = optimizer.minimize(loss, training.get_global_step())\n\n      return tpu_estimator.TPUEstimatorSpec(\n          mode, loss=loss, train_op=train_op, export_outputs=export_outputs)\n\n    def _input_fn(params):\n      batch_size = params['batch_size']\n\n      x = dataset_ops.Dataset.from_tensor_slices(\n          (random_uniform([4, 1]),\n           random_uniform([4, 1], maxval=100, dtype=dtypes.float32)))\n      dataset_labels = dataset_ops.Dataset.from_tensor_slices(\n          random_uniform([4, 1]))\n      dataset = dataset_ops.Dataset.zip((x, dataset_labels))\n\n      def _map_fn(x, y):\n        if nest_type == 'dict':\n          return {'x': x}, y\n        elif nest_type == 'tuple':\n          return tuple(x), y\n        elif nest_type == 'namedtuple':\n          return MyTuple(*x), y\n        else:\n          return x, y\n\n      dataset = dataset.map(_map_fn)\n      dataset = dataset.batch(batch_size, drop_remainder=True)\n      dataset = dataset.repeat(-1)\n      return dataset\n\n    run_config = 
create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=model_fn, config=run_config, train_batch_size=4)\n    est.train(_input_fn, steps=4)\n\n\nclass _DummyHook(session_run_hook.SessionRunHook):\n  \"\"\"Check whether this hook is called or not.\"\"\"\n\n  def __init__(self):\n    \"\"\"Constructs the run hook.\"\"\"\n    self._called = False\n\n  def after_create_session(self, sees, coord):\n    del sees, coord\n    self._called = True\n\n  @property\n  def called(self):\n    return self._called\n\n\nclass TPUEstimatorModelFnTest(test.TestCase):\n\n  def test_succeed_with_missing_labels(self):\n\n    def _model_fn(features, mode, params):\n      labels = features.pop('y')\n      return get_model_fn()(features, labels, mode, params)\n\n    def _input_fn_without_labels(params):\n      batch_size = params['batch_size']\n      features, labels = dummy_input_fn(batch_size)\n      return {'x': features['x'], 'y': labels}\n\n    run_config = create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn, config=run_config, train_batch_size=16)\n    est.train(_input_fn_without_labels, steps=1)\n\n  def test_succeed_with_log_step_count_steps_none(self):\n\n    def _model_fn(features, mode, params):\n      labels = features.pop('y')\n      return get_model_fn()(features, labels, mode, params)\n\n    def _input_fn_without_labels(params):\n      batch_size = params['batch_size']\n      features, labels = dummy_input_fn(batch_size)\n      return {'x': features['x'], 'y': labels}\n\n    run_config = create_run_config(iterations_per_loop=4)\n    run_config = run_config.replace(log_step_count_steps=None)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn, config=run_config, train_batch_size=16)\n    est.train(_input_fn_without_labels, steps=1)\n\n  def test_missing_labels_in_model_fn_not_input_fn(self):\n\n    def _model_fn(features, mode, params):\n      del features, mode, 
params  # unused.\n      return tpu_estimator.TPUEstimatorSpec()\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    run_config = create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn, config=run_config, train_batch_size=16)\n\n    with self.assertRaisesRegex(\n        ValueError,\n        'model_fn does not take labels, but input_fn returns labels'):\n      est.train(_input_fn, steps=1)\n\n  def test_missing_params(self):\n\n    def _model_fn(features, labels, mode):\n      del features, labels, mode  # unused.\n      return tpu_estimator.TPUEstimatorSpec()\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    run_config = create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn, config=run_config, train_batch_size=16)\n\n    with self.assertRaisesRegex(ValueError,\n                                'model_fn .* does not include params'):\n      est.train(_input_fn, steps=1)\n\n  def test_invalid_arg(self):\n\n    def _model_fn(features, labels, invalid_arg):\n      del features, labels, invalid_arg  # unused.\n      return tpu_estimator.TPUEstimatorSpec()\n\n    with self.assertRaisesRegex(ValueError,\n                                'model_fn .* has following not expected args'):\n      run_config = create_run_config(iterations_per_loop=4)\n      tpu_estimator.TPUEstimator(\n          model_fn=_model_fn, config=run_config, train_batch_size=16)\n\n  def test_valid_training_hook(self):\n    run_config = create_run_config(iterations_per_loop=4)\n    dummy_hook = _DummyHook()\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    def _model_fn(features, labels, mode, params):\n      spec = get_model_fn()(features, labels, mode, params)\n      return model_fn_lib.EstimatorSpec(\n          mode=mode,\n          train_op=spec.train_op,\n          loss=spec.loss,\n    
      training_hooks=[dummy_hook])\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=2 * FLAGS.test_num_shards)\n\n    est.train(input_fn=_input_fn, steps=1)\n    self.assertTrue(dummy_hook.called)\n\n  def test_valid_eval_hook(self):\n    run_config = create_run_config(iterations_per_loop=4)\n    dummy_hook = _DummyHook()\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    def _model_fn(features, labels, mode, params):\n      spec = get_model_fn()(features, labels, mode, params)\n      return tpu_estimator.TPUEstimatorSpec(\n          mode=mode,\n          train_op=spec.train_op,\n          loss=spec.loss,\n          evaluation_hooks=[dummy_hook])\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=2 * FLAGS.test_num_shards,\n        eval_batch_size=2 * FLAGS.test_num_shards)\n\n    est.evaluate(input_fn=_input_fn, steps=1)\n    self.assertTrue(dummy_hook.called)\n\n  def test_valid_prediction_hook(self):\n    run_config = create_run_config(iterations_per_loop=4)\n    dummy_hook = _DummyHook()\n\n    def _input_fn(params):\n      return dummy_input_fn_with_dataset(params['batch_size'], repeat=False)\n\n    def _model_fn(features, labels, mode, params):\n      del labels, params\n      predictions = dense_computation(features)\n      return tpu_estimator.TPUEstimatorSpec(\n          mode=mode,\n          train_op=None,\n          loss=None,\n          predictions={'predictions': predictions},\n          prediction_hooks=[dummy_hook])\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=2 * FLAGS.test_num_shards,\n        predict_batch_size=2 * FLAGS.test_num_shards)\n\n    list(est.predict(input_fn=_input_fn))\n    self.assertTrue(dummy_hook.called)\n\n  def test_invalid_training_chief_hook(self):\n    run_config = 
create_run_config(iterations_per_loop=4)\n    dummy_hook = session_run_hook.SessionRunHook()\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    def _model_fn(features, labels, mode, params):\n      spec = get_model_fn()(features, labels, mode, params)\n      return model_fn_lib.EstimatorSpec(\n          mode=mode,\n          train_op=spec.train_op,\n          loss=spec.loss,\n          training_chief_hooks=[dummy_hook])\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=2 * FLAGS.test_num_shards)\n\n    with self.assertRaisesRegex(\n        ValueError, 'training_chief_hooks returned by '\n        'EstimatorSpec is not supported in '\n        'TPUEstimator'):\n      est.train(input_fn=_input_fn, steps=1)\n\n  def test_access_device_assignment_in_model_fn(self):\n\n    def _model_fn(features, labels, mode, params):\n      ctx = params['context']\n      self.assertIsInstance(ctx.device_assignment,\n                            tf.tpu.experimental.DeviceAssignment)\n      return get_model_fn()(features, labels, mode, params)\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    FLAGS.test_num_shards //= 2\n    run_config = create_run_config(\n        iterations_per_loop=4, num_cores_per_replica=2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn, config=run_config, train_batch_size=16)\n    est.train(_input_fn, steps=4)\n    FLAGS.test_num_shards *= 2\n\n  def test_fail_to_call_deployment_in_model_fn(self):\n\n    def _model_fn(features, labels, mode, params):\n      ctx = params['context']\n      with self.assertRaisesRegex(\n          RuntimeError, 'This TPUContext instance must not be '\n          'called from model_fn.'):\n        ctx.current_input_fn_deployment()\n      return get_model_fn()(features, labels, mode, params)\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n 
   run_config = create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn, config=run_config, train_batch_size=16)\n    est.train(_input_fn, steps=4)\n\n\nclass TPUEstimatorPredictionTest(test.TestCase):\n\n  def _test_train_and_predict(self, run_config, dataset_size,\n                              input_tensor=None):\n    \"\"\"Trains the model and returns the list of global steps after each loop.\"\"\"\n\n    def train_input_fn(params):\n      return dummy_input_fn_with_dataset(\n          dataset_size,\n          repeat=True,\n          x=input_tensor,\n          batch_size=params['batch_size'])\n\n    def predict_input_fn(params):\n      return dummy_input_fn_with_dataset(\n          dataset_size,\n          repeat=False,\n          x=input_tensor,\n          batch_size=params['batch_size'])\n\n    def _model_fn(features, labels, mode, params):\n      return get_model_fn()(features, labels, mode, params)\n\n    batch_size = 16\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=batch_size,\n        eval_batch_size=batch_size,\n        predict_batch_size=batch_size)\n\n    est.train(train_input_fn, steps=1)\n    predictions = list(est.predict(predict_input_fn))\n    if (run_config.tpu_config.per_host_input_for_training == tpu_config\n        .InputPipelineConfig.BROADCAST):\n      expected_size = batch_size\n    elif (run_config.tpu_config.per_host_input_for_training == tpu_config\n          .InputPipelineConfig.PER_HOST_V2):\n      expected_size = dataset_size\n    else:\n      expected_size = batch_size\n\n    self.assertEqual(expected_size, len(predictions))\n\n  def _construct_run_config(self,\n                            mode,\n                            num_shards=2,\n                            input_partition_dims=None,\n                            num_cores_per_replica=None):\n\n    return create_run_config(\n        
iterations_per_loop=4,\n        num_shards=num_shards,\n        per_host_input_for_training=mode,\n        input_partition_dims=input_partition_dims,\n        num_cores_per_replica=num_cores_per_replica)\n\n  def test_train_and_predict_per_host_v1(self):\n    self._test_train_and_predict(\n        self._construct_run_config(tpu_config.InputPipelineConfig.PER_HOST_V1),\n        16)\n\n  def test_train_and_predict_per_host_v2_evenly_distributed(self):\n    self._test_train_and_predict(\n        self._construct_run_config(tpu_config.InputPipelineConfig.PER_HOST_V2),\n        16)\n\n  def test_train_and_predict_per_host_v2_not_evenly_distributed(self):\n    self._test_train_and_predict(\n        self._construct_run_config(tpu_config.InputPipelineConfig.PER_HOST_V2),\n        24)\n\n  def test_train_and_predict_with_input_partition(self):\n    self._test_train_and_predict(\n        self._construct_run_config(\n            tpu_config.InputPipelineConfig.PER_HOST_V2,\n            num_shards=1,\n            input_partition_dims=[{\n                'x': [1, 2, 1, 1]\n            }, None],\n            num_cores_per_replica=2), 16,\n        np.zeros((16, 32, 32, 3), dtype=np.float32))\n\n  def test_train_and_predict_broadcast(self):\n    self._test_train_and_predict(\n        self._construct_run_config(tpu_config.InputPipelineConfig.BROADCAST),\n        16)\n\n  def test_non_static_shape(self):\n\n    def predict_input_fn(params):\n      return dummy_input_fn_with_dataset(params['batch_size'], repeat=False)\n\n    def _model_fn(features, labels, mode, params):\n      spec = get_model_fn()(features, labels, mode, params)\n      spec.predictions['dummy'] = array_ops.placeholder(\n          dtypes.float32, shape=(None, 24))\n      return spec\n\n    batch_size = 16\n    run_config = create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=batch_size,\n        
eval_batch_size=batch_size,\n        predict_batch_size=batch_size)\n\n    with self.assertRaisesRegex(ValueError, 'should be static'):\n      list(est.predict(predict_input_fn))\n\n  def test_predict_on_cpu(self):\n    \"\"\"Trains the model and returns the list of global steps after each loop.\"\"\"\n\n    def train_input_fn(params):\n      return dummy_input_fn_with_dataset(params['batch_size'], repeat=True)\n\n    def predict_input_fn(params):\n      # A fixed input\n      x = np.linspace(\n          0.0, 100.0, num=batch_size).reshape(batch_size, 1).astype(np.float32)\n\n      return dummy_input_fn_with_dataset(\n          params['batch_size'], repeat=False, x=x)\n\n    def _model_fn(features, labels, mode, params):\n      return get_model_fn()(features, labels, mode, params)\n\n    batch_size = 16\n    run_config = create_run_config(iterations_per_loop=4)\n    tpu_est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=batch_size,\n        eval_batch_size=batch_size,\n        predict_batch_size=batch_size,\n        use_tpu=True)\n\n    tpu_est.train(train_input_fn, steps=1)\n    tpu_predictions = [\n        x['predictions'] for x in tpu_est.predict(predict_input_fn)\n    ]\n    self.assertEqual(batch_size * 1, len(tpu_predictions))\n\n    cpu_est = tpu_estimator.TPUEstimator(\n        model_dir=tpu_est.model_dir,  # To load the ckpt.\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=batch_size,\n        eval_batch_size=batch_size,\n        predict_batch_size=batch_size,\n        use_tpu=False)\n    cpu_predictions = [\n        x['predictions'] for x in cpu_est.predict(predict_input_fn)\n    ]\n    self.assertEqual(batch_size * 1, len(cpu_predictions))\n\n    self.assertAllClose(tpu_predictions, cpu_predictions, atol=0.01)\n\n  def test_train_and_export(self):\n\n    def train_input_fn(params):\n      return dummy_input_fn_with_dataset(params['batch_size'], 
repeat=True)\n\n    def _model_fn(features, labels, mode, params):\n      return get_model_fn()(features, labels, mode, params)\n\n    batch_size = 16\n    run_config = create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=batch_size,\n        eval_batch_size=batch_size,\n        predict_batch_size=1)\n\n    est.train(train_input_fn, steps=1)\n\n    # Even though the predict_batch_size is 1, not divisible by the num_shards\n    # (2 in this case), export_savedmodel should not trigger the TPU validation.\n    # This test ensures that the predict mode is handled correctly inside\n    # TPUEstimator.\n    feature_spec = {'x': parsing_ops.FixedLenFeature([1], dtypes.float32)}\n    serving_input_receiver_fn = (\n        export.build_parsing_serving_input_receiver_fn(feature_spec))\n    est.export_saved_model(\n        tempfile.mkdtemp(dir=self.get_temp_dir()), serving_input_receiver_fn)\n\n\nclass TPUEstimatorTrainingTest(test.TestCase):\n\n  def _train_and_return_global_steps(self,\n                                     iterations_per_loop,\n                                     steps=None,\n                                     max_steps=None,\n                                     pre_train_steps=None):\n    \"\"\"Trains the model and returns the list of global steps after each loop.\"\"\"\n\n    def input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    def _model_fn(features, labels, mode, params):\n      return get_model_fn()(features, labels, mode, params)\n\n    run_config = create_run_config(iterations_per_loop=iterations_per_loop)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16)\n\n    class _TrainStepCheckHook(session_run_hook.SessionRunHook):\n      \"\"\"Check eval step counter after one session.run.\"\"\"\n\n      def 
__init__(self):\n        \"\"\"Constructs the run hook.\"\"\"\n        self._global_steps = []\n\n      @property\n      def global_steps(self):\n        return self._global_steps\n\n      def after_run(self, run_context, run_values):\n        global_step = run_context.session.run(training.get_global_step())\n        self._global_steps.append(global_step)\n\n    if pre_train_steps:\n      est.train(input_fn, steps=pre_train_steps)\n\n    hook = _TrainStepCheckHook()\n    est.train(input_fn, steps=steps, max_steps=max_steps, hooks=[hook])\n    return hook.global_steps\n\n  def test_train_steps_not_divisible_by_iterations(self):\n    # From scratch.\n    global_steps_per_loop = self._train_and_return_global_steps(\n        iterations_per_loop=4, steps=10)\n    self.assertEqual([4, 8, 10], global_steps_per_loop)\n\n    # From existing checkpoint.\n    global_steps_per_loop = self._train_and_return_global_steps(\n        iterations_per_loop=4, steps=10, pre_train_steps=3)\n    self.assertEqual([7, 11, 13], global_steps_per_loop)\n\n  def test_train_steps_divisible_by_iterations(self):\n    # From scratch.\n    global_steps_per_loop = self._train_and_return_global_steps(\n        iterations_per_loop=4, steps=12)\n    self.assertEqual([4, 8, 12], global_steps_per_loop)\n\n    # From existing checkpoint.\n    global_steps_per_loop = self._train_and_return_global_steps(\n        iterations_per_loop=4, steps=12, pre_train_steps=3)\n    self.assertEqual([7, 11, 15], global_steps_per_loop)\n\n  def test_train_steps_with_large_iterations(self):\n    # From scratch.\n    global_steps_per_loop = self._train_and_return_global_steps(\n        iterations_per_loop=40, steps=12)\n    self.assertEqual([12], global_steps_per_loop)\n\n    # From existing checkpoint.\n    global_steps_per_loop = self._train_and_return_global_steps(\n        iterations_per_loop=40, steps=12, pre_train_steps=3)\n    self.assertEqual([15], global_steps_per_loop)\n\n  def 
test_train_max_steps_not_divisible_by_iterations(self):\n    # From scratch.\n    global_steps_per_loop = self._train_and_return_global_steps(\n        iterations_per_loop=4, max_steps=10)\n    self.assertEqual([4, 8, 10], global_steps_per_loop)\n\n    # From existing checkpoint.\n    global_steps_per_loop = self._train_and_return_global_steps(\n        iterations_per_loop=4, max_steps=10, pre_train_steps=3)\n    self.assertEqual([7, 10], global_steps_per_loop)\n\n  def test_train_max_steps_divisible_by_iterations(self):\n    # From scratch.\n    global_steps_per_loop = self._train_and_return_global_steps(\n        iterations_per_loop=4, max_steps=12)\n    self.assertEqual([4, 8, 12], global_steps_per_loop)\n\n    # From existing checkpoint.\n    global_steps_per_loop = self._train_and_return_global_steps(\n        iterations_per_loop=4, max_steps=15, pre_train_steps=3)\n    self.assertEqual([7, 11, 15], global_steps_per_loop)\n\n  def test_train_max_steps_with_large_iterations(self):\n    # From scratch.\n    global_steps_per_loop = self._train_and_return_global_steps(\n        iterations_per_loop=40, max_steps=12)\n    self.assertEqual([12], global_steps_per_loop)\n\n    # From existing checkpoint.\n    global_steps_per_loop = self._train_and_return_global_steps(\n        iterations_per_loop=40, max_steps=12, pre_train_steps=3)\n    self.assertEqual([12], global_steps_per_loop)\n\n  def test_error_out_if_train_steps_is_float(self):\n    with self.assertRaisesRegex(TypeError, 'must be int'):\n      self._train_and_return_global_steps(iterations_per_loop=40, steps=12.3)\n\n  def test_error_out_if_train_steps_is_invalid(self):\n    with self.assertRaisesRegex(ValueError, 'Must specify.*> 0'):\n      self._train_and_return_global_steps(iterations_per_loop=40, steps=-32)\n\n  def test_error_out_if_train_max_steps_is_float(self):\n    with self.assertRaisesRegex(TypeError, 'must be int'):\n      self._train_and_return_global_steps(\n          iterations_per_loop=40, 
max_steps=12.3)\n\n  def test_error_out_if_train_max_steps_is_invalid(self):\n    with self.assertRaisesRegex(ValueError, 'Must specify.*> 0'):\n      self._train_and_return_global_steps(iterations_per_loop=40, max_steps=-32)\n\n  def test_warm_starts(self):\n\n    def _make_model_fn(x, use_tpu):\n\n      def _variable_creating_model_fn(features, labels, mode, params):\n        del params\n        loss = None\n        train_op = None\n        variable_scope.get_variable('x', initializer=x)\n        predictions = dense_computation(features)\n        loss = losses.mean_squared_error(labels, predictions)\n        optimizer = training.GradientDescentOptimizer(learning_rate=0.5)\n        if use_tpu:\n          optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)\n        train_op = optimizer.minimize(loss, training.get_global_step())\n        if use_tpu:\n          return tpu_estimator.TPUEstimatorSpec(\n              mode, loss=loss, train_op=train_op)\n        else:\n          return model_fn_lib.EstimatorSpec(\n              mode,\n              loss=constant_op.constant(1.),\n              train_op=state_ops.assign_add(training.get_global_step(), 1))\n\n      return _variable_creating_model_fn\n\n    def input_fn(params):\n      return dummy_input_fn(params.get('batch_size', 16))\n\n    run_config = create_run_config(iterations_per_loop=1)\n    tpu_est = tpu_estimator.TPUEstimator(\n        model_fn=_make_model_fn(42., use_tpu=True),\n        config=run_config,\n        train_batch_size=16,\n        eval_batch_size=16)\n    tpu_est.train(input_fn, steps=10)\n\n    warm_started_est = estimator_lib.Estimator(\n        model_fn=_make_model_fn(36., use_tpu=False),\n        warm_start_from=tpu_est.model_dir)\n    warm_started_est.train(input_fn, steps=5)\n    # warm_start is called after the model_fn, so x should have the value\n    # from the checkpoint.\n    self.assertEqual(42., warm_started_est.get_variable_value('x'))\n\n\nclass 
TPUEstimatorValidationTest(parameterized.TestCase, test.TestCase):\n\n  def _query_system(self, master_address, cluster_def, query_topology):\n    del master_address, cluster_def, query_topology\n    return tpu_system_metadata_lib.TPUSystemMetadata(\n        num_cores=16,\n        num_hosts=2,\n        num_of_cores_per_host=8,\n        topology=None,\n        devices=[])\n\n  def test_error_if_cross_replica_sum_missing(self):\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    def _model_fn(features, labels, params):\n      del params\n      predictions = layers.dense(\n          features['x'], 1, kernel_initializer=init_ops.zeros_initializer())\n      loss = losses.mean_squared_error(labels, predictions)\n      optimizer = training.GradientDescentOptimizer(learning_rate=0.5)\n      train_op = optimizer.minimize(loss, training.get_global_step())\n\n      return tpu_estimator.TPUEstimatorSpec(\n          mode=None, loss=loss, train_op=train_op)\n\n    run_config = create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn, train_batch_size=8, config=run_config, params={})\n\n    with self.assertRaisesRegex(ValueError, 'model training on TPUs'):\n      est.train(input_fn=_input_fn, steps=1)\n\n  def test_no_error_if_cross_replica_sum_present(self):\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    run_config = create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(),\n        train_batch_size=8,\n        config=run_config,\n        params={})\n    est.train(input_fn=_input_fn, steps=1)\n\n  def test_error_dynamic_shape_tensor_features_for_model(self):\n    \"\"\"Asserting that features Tensor to TPUEstimator model has static shape.\n\n    \"\"\"\n\n    def _input_fn(params):\n      features = reshape(\n          math_ops.range(params['batch_size'] * 64, dtype=dtypes.float32),\n          
(params['batch_size'], 64))\n      # Make features with dynamic shape by the help of random padding.\n      padding = random_uniform([], minval=0, maxval=10, dtype=dtypes.int32)\n      features = array_ops.pad(features, [(0, 0), (0, padding)])\n      return dataset_lib.Dataset.from_tensor_slices(\n          (features, math_ops.range(params['batch_size']) % 10)).repeat().batch(\n              16, drop_remainder=True)\n\n    def _model_fn(features, labels, mode, params):\n      del labels\n      del params\n      if mode == _PREDICT:\n        return tpu_estimator.TPUEstimatorSpec(\n            mode=mode, predictions={'value': features})\n\n    run_config = create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        train_batch_size=8,\n        config=run_config,\n        predict_batch_size=16,\n        params={})\n    with self.assertRaisesRegex(ValueError, 'features.*must have static'):\n      list(est.predict(_input_fn))\n\n  def test_error_dynamic_shape_dict_tensor_features_for_model(self):\n    \"\"\"Asserting that features dict to TPUEstimator model has static shape.\n\n    \"\"\"\n\n    def _input_fn_dict(params):\n      features = reshape(\n          math_ops.range(params['batch_size'] * 64, dtype=dtypes.float32),\n          (params['batch_size'], 64))\n      # Make features with dynamic shape by the help of random padding.\n      padding = random_uniform([], minval=0, maxval=10, dtype=dtypes.int32)\n      features = array_ops.pad(features, [(0, 0), (0, padding)])\n      dataset = dataset_lib.Dataset.from_tensor_slices(features)\n      dataset = dataset.map(lambda v: {'key': v})\n      return dataset.repeat().batch(16, drop_remainder=True)\n\n    def _model_fn(features, labels, mode, params):\n      del labels\n      del params\n      if mode == _PREDICT:\n        return tpu_estimator.TPUEstimatorSpec(\n            mode=mode, predictions={'value': features['key']})\n\n    run_config = 
create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        train_batch_size=8,\n        config=run_config,\n        predict_batch_size=16,\n        params={})\n    with self.assertRaisesRegex(ValueError, 'features.*must have static.*'):\n      list(est.predict(_input_fn_dict))\n\n  def test_error_dynamic_shape_tensor_labels_for_model(self):\n    \"\"\"Asserting that labels to TPUEstimator model has static shape.\n\n    \"\"\"\n\n    def _input_fn(params):\n      features = reshape(\n          math_ops.range(params['batch_size'] * 64, dtype=dtypes.float32),\n          (params['batch_size'], 64))\n      labels = reshape(\n          math_ops.range(params['batch_size'] * 64, dtype=dtypes.float32),\n          (params['batch_size'], 64))\n      # Make labels with dynamic shape by the help of random padding.\n      padding = random_uniform([], minval=0, maxval=10, dtype=dtypes.int32)\n      labels = array_ops.pad(labels, [(0, 0), (0, padding)])\n      dataset = dataset_lib.Dataset.from_tensor_slices((features, labels))\n      return dataset.repeat().batch(16, drop_remainder=True)\n\n    def _model_fn(features, labels, mode, params):\n      del labels\n      del params\n      if mode == _PREDICT:\n        return tpu_estimator.TPUEstimatorSpec(\n            mode=mode, predictions={'value': features})\n\n    run_config = create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        train_batch_size=8,\n        config=run_config,\n        predict_batch_size=16,\n        params={})\n    with self.assertRaisesRegex(ValueError, 'labels.*must have static'):\n      list(est.predict(_input_fn))\n\n  def test_error_dynamic_shape_dict_tensor_labels_for_model(self):\n    \"\"\"Asserting that labels dict to TPUEstimator model has static shape.\n\n    \"\"\"\n\n    def _input_fn_dict(params):\n      features = reshape(\n          math_ops.range(params['batch_size'] * 64, 
dtype=dtypes.float32),\n          (params['batch_size'], 64))\n      labels = reshape(\n          math_ops.range(params['batch_size'] * 64, dtype=dtypes.float32),\n          (params['batch_size'], 64))\n      # Make labels with dynamic shape by the help of random padding.\n      padding = random_uniform([], minval=0, maxval=10, dtype=dtypes.int32)\n      labels = array_ops.pad(labels, [(0, 0), (0, padding)])\n      dataset = dataset_lib.Dataset.from_tensor_slices((features, labels))\n      dataset = dataset.map(lambda f, l: ({'fkey': f}, {'lkey': l}))\n      return dataset.repeat().batch(16, drop_remainder=True)\n\n    def _model_fn(features, labels, mode, params):\n      del labels\n      del params\n      if mode == _PREDICT:\n        return tpu_estimator.TPUEstimatorSpec(\n            mode=mode, predictions={'value': features['fkey']})\n\n    run_config = create_run_config(iterations_per_loop=4)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        train_batch_size=8,\n        config=run_config,\n        predict_batch_size=16,\n        params={})\n    with self.assertRaisesRegex(ValueError, 'labels.*must have static*. 
shape'):\n      list(est.predict(_input_fn_dict))\n\n  def test_error_none_eval_batch_size_for_evaluation_mode(self):\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    with self.assertRaisesRegex(\n        ValueError,\n        'eval_batch_size in TPUEstimator constructor cannot be `None`'):\n      est = tpu_estimator.TPUEstimator(\n          model_fn=get_model_fn(),\n          config=create_run_config(iterations_per_loop=4),\n          train_batch_size=64,\n          use_tpu=True)\n      est.evaluate(_input_fn, steps=1)\n\n  def test_error_none_predict_batch_size_for_prediction_mode(self):\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    with self.assertRaisesRegex(\n        ValueError,\n        'predict_batch_size in TPUEstimator constructor cannot be `None`'):\n      est = tpu_estimator.TPUEstimator(\n          model_fn=get_model_fn(),\n          config=create_run_config(iterations_per_loop=4),\n          train_batch_size=64,\n          use_tpu=True)\n      list(est.predict(_input_fn))\n\n  @parameterized.parameters(\n      (tpu_config.InputPipelineConfig.PER_HOST_V1, 'evaluate'),\n      (tpu_config.InputPipelineConfig.PER_HOST_V1, 'predict'),\n      (tpu_config.InputPipelineConfig.PER_HOST_V2, 'predict'))\n  def test_error_num_hosts_and_replicas_larger_than_1_in_eval_and_predict_mode(\n      self, input_pipeline_mode, predict_or_evaluate):\n\n    def _input_fn(params):\n      return dummy_input_fn(params['batch_size'])\n\n    with test.mock.patch.object(\n        tpu_system_metadata_lib,\n        '_query_tpu_system_metadata',\n        side_effect=self._query_system):\n\n      run_config = create_run_config(\n          iterations_per_loop=1, num_cores_per_replica=8,\n          per_host_input_for_training=input_pipeline_mode)\n\n      if predict_or_evaluate == 'evaluate':\n        expected_error_re = ('TPUEstimator.evaluate is only supported '\n                             'under three 
conditions')\n      else:\n        expected_error_re = ('TPUEstimator.predict is only supported '\n                             'under three conditions')\n\n      with self.assertRaisesRegex(ValueError, expected_error_re):\n        est = tpu_estimator.TPUEstimator(\n            model_fn=get_model_fn(),\n            config=run_config,\n            train_batch_size=32,\n            eval_batch_size=32,\n            predict_batch_size=32,\n            use_tpu=True)\n        if predict_or_evaluate == 'evaluate':\n          est.evaluate(_input_fn, steps=1)\n        else:\n          list(est.predict(_input_fn))\n\n  def test_evaluate_1host_and_replicas_larger_than_1_with_PER_HOST_V2(\n      self):\n    fake_num_cores = 32\n\n    def _input_fn(params):\n      batch_size = params['batch_size']\n      x = np.random.normal(size=[batch_size, 20]).astype(np.float32)\n      return dummy_input_fn_with_dataset(batch_size, repeat=False, x=x)\n\n    with test.mock.patch.object(\n        tpu_system_metadata_lib,\n        '_query_tpu_system_metadata',\n        side_effect=self._query_system):\n\n      run_config = create_run_config(\n          iterations_per_loop=4,\n          num_shards=fake_num_cores // 2,\n          per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2\n      )\n\n      est = tpu_estimator.TPUEstimator(\n          model_fn=get_model_fn(),\n          config=run_config,\n          train_batch_size=32,\n          eval_batch_size=32,\n          predict_batch_size=32,\n          use_tpu=True)\n\n      # This exception is ok as we do not have sufficient TPU cores to run the\n      # model.\n      with self.assertRaisesRegex(errors.InvalidArgumentError,\n                                  'there are only 2 cores in the TPU topology'):\n        est.evaluate(_input_fn, steps=1)\n\n  @parameterized.parameters(\n      (tpu_config.InputPipelineConfig.BROADCAST, 'evaluate'),\n      (tpu_config.InputPipelineConfig.PER_HOST_V1, 'evaluate'),\n      
(tpu_config.InputPipelineConfig.PER_HOST_V2, 'evaluate'),\n      (tpu_config.InputPipelineConfig.BROADCAST, 'predict'),\n      (tpu_config.InputPipelineConfig.PER_HOST_V1, 'predict'),\n      (tpu_config.InputPipelineConfig.PER_HOST_V2, 'predict'))\n  def test_no_error_1host_1replica_in_eval_and_predict_mode(\n      self, input_pipeline_mode, predict_or_evaluate):\n\n    def _input_fn(params):\n      return dummy_input_fn_with_dataset(\n          dataset_size=params['batch_size'], repeat=False)\n\n    FLAGS.test_num_shards = None\n    run_config = create_run_config(\n        iterations_per_loop=1, num_cores_per_replica=2,\n        per_host_input_for_training=input_pipeline_mode)\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(),\n        config=run_config,\n        train_batch_size=32,\n        eval_batch_size=32,\n        predict_batch_size=32,\n        use_tpu=True)\n    if predict_or_evaluate == 'evaluate':\n      est.evaluate(_input_fn, steps=1)\n    else:\n      list(est.predict(_input_fn))\n\n\nclass TPUConfigTest(test.TestCase):\n\n  def _create_ctx(self, run_config, mode=_TRAIN):\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(), config=run_config, train_batch_size=16)\n\n    with est._ctx.with_mode(mode) as ctx:\n      return ctx\n\n  def test_no_cluster_spec(self):\n    run_config = tpu_config.RunConfig()\n    ctx = self._create_ctx(run_config)\n    self.assertIsNone(ctx.master_job)\n    ctx = self._create_ctx(run_config, mode=_EVAL)\n    self.assertIsNone(ctx.master_job)\n\n    run_config = tpu_config.RunConfig(master='grpc://10.4.5.7:8470')\n    ctx = self._create_ctx(run_config)\n    self.assertEqual('tpu_worker', ctx.master_job)\n    ctx = self._create_ctx(run_config, mode=_EVAL)\n    self.assertEqual('tpu_worker', ctx.master_job)\n\n    run_config = tpu_config.RunConfig(\n        master='grpc://10.4.5.7:8470', evaluation_master='grpc://10.5.6.7:8470')\n    ctx = self._create_ctx(run_config)\n    
self.assertEqual('tpu_worker', ctx.master_job)\n    ctx = self._create_ctx(run_config, mode=_EVAL)\n    self.assertEqual('tpu_worker', ctx.master_job)\n\n  def test_cluster_spec_prop(self):\n    cluster_def = cluster_pb2.ClusterDef()\n    worker_job = cluster_def.job.add()\n    worker_job.name = 'worker'\n    worker_job.tasks[0] = 'grpc://10.2.3.4:8470'\n    session_config = config_pb2.ConfigProto(cluster_def=cluster_def)\n    run_config = tpu_config.RunConfig(\n        session_config=session_config, master='grpc://10.2.3.4:8470')\n\n    ctx = self._create_ctx(run_config)\n    self.assertEqual('worker', ctx.master_job)\n\n  def test_cluster_spec_prop_multi_jobs(self):\n    cluster_def = cluster_pb2.ClusterDef()\n    worker_job = cluster_def.job.add()\n    worker_job.name = 'worker'\n    worker_job.tasks[0] = 'grpc://10.2.3.4:8470'\n    coordinator_job = cluster_def.job.add()\n    coordinator_job.name = 'coordinator'\n    coordinator_job.tasks[0] = 'grpc://10.2.3.4:8470'\n    session_config = config_pb2.ConfigProto(cluster_def=cluster_def)\n    run_config = tpu_config.RunConfig(\n        session_config=session_config, master='grpc://10.2.3.4:8470')\n\n    ctx = self._create_ctx(run_config)\n    self.assertEqual('worker', ctx.master_job)\n\n  def test_cluster_spec_prop_cannot_infer(self):\n    # No coordinator.\n    cluster_def = cluster_pb2.ClusterDef()\n    worker_job = cluster_def.job.add()\n    worker_job.name = 'worker'\n    worker_job.tasks[0] = 'grpc://10.2.3.4:8470'\n    coordinator_job = cluster_def.job.add()\n    coordinator_job.name = 'other_worker'\n    coordinator_job.tasks[0] = 'grpc://10.2.3.4:8470'\n    session_config = config_pb2.ConfigProto(cluster_def=cluster_def)\n    run_config = tpu_config.RunConfig(\n        session_config=session_config, master='grpc://10.2.3.4:8470')\n    with self.assertRaises(ValueError):\n      ctx = self._create_ctx(run_config)\n      ctx.master_job  # pylint:disable=pointless-statement\n\n    # 2 non-coordinator jobs.\n  
  cluster_def = cluster_pb2.ClusterDef()\n    worker_job = cluster_def.job.add()\n    worker_job.name = 'worker'\n    worker_job.tasks[0] = 'grpc://10.2.3.4:8470'\n    other_worker_job = cluster_def.job.add()\n    other_worker_job.name = 'other_worker'\n    other_worker_job.tasks[0] = 'grpc://10.2.3.5:8470'\n    coordinator_job = cluster_def.job.add()\n    coordinator_job.name = 'coordinator'\n    coordinator_job.tasks[0] = 'grpc://10.2.3.4:8470'\n    session_config = config_pb2.ConfigProto(cluster_def=cluster_def)\n    run_config = tpu_config.RunConfig(\n        session_config=session_config, master='grpc://10.2.3.4:8470')\n    with self.assertRaises(ValueError):\n      ctx = self._create_ctx(run_config)\n      ctx.master_job  # pylint:disable=pointless-statement\n\n  def test_session_config_none(self):\n    run_config = tpu_config.RunConfig()\n    self.assertIsNone(run_config.session_config)\n    ctx = self._create_ctx(run_config)\n    self.assertIsNone(ctx.master_job)\n\n    run_config = tpu_config.RunConfig(master='grpc://10.2.3.4:8470')\n    self.assertIsNone(run_config.session_config)\n    ctx = self._create_ctx(run_config)\n    self.assertEqual('tpu_worker', ctx.master_job)\n\n  def test_override_name(self):\n    tpu_cfg = tpu_config.TPUConfig(tpu_job_name='my_custom_job')\n    run_config = tpu_config.RunConfig(tpu_config=tpu_cfg)\n    ctx = self._create_ctx(run_config)\n    self.assertEqual('my_custom_job', ctx.master_job)\n\n  def test_evaluation_master(self):\n    run_config = tpu_config.RunConfig(master='grpc://10.2.3.4:8470')\n    self.assertEqual(run_config.master, run_config.evaluation_master)\n\n    run_config = tpu_config.RunConfig(\n        master='grpc://10.2.3.4:8470', evaluation_master='grpc://1.1.1.1:8470')\n    self.assertEqual('grpc://1.1.1.1:8470', run_config.evaluation_master)\n\n  def test_input_partition_config(self):\n    with self.assertRaisesRegex(ValueError,\n                                'input_partition_dims is.* PER_HOST_V2 
mode.'):\n      tpu_config.TPUConfig(\n          num_shards=1, input_partition_dims=[[1, 2, 1, 1], None])\n\n    with self.assertRaisesRegex(ValueError,\n                                '.*requires setting num_cores_per_replica.'):\n      tpu_config.TPUConfig(\n          num_shards=1,\n          per_host_input_for_training=tpu_config.InputPipelineConfig\n          .PER_HOST_V2,\n          input_partition_dims=[[1, 2, 1, 1], None])\n\n    with self.assertRaisesRegex(ValueError, '.*with one or two elements.'):\n      tpu_config.TPUConfig(\n          num_shards=1,\n          per_host_input_for_training=tpu_config.InputPipelineConfig\n          .PER_HOST_V2,\n          input_partition_dims=[[1, 2, 1, 1], None, None])\n\n    tpu_config.TPUConfig(\n        num_shards=1,\n        num_cores_per_replica=2,\n        per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2,\n        input_partition_dims=[[1, 2, 1, 1], None])\n\n\nclass TPUEstimatorInputPartitionValidationTest(test.TestCase):\n\n  def _train(self,\n             iterations_per_loop,\n             image_height=224,\n             image_width=224,\n             steps=None,\n             num_shards=None,\n             num_cores_per_replica=None,\n             input_partition_dims=None):\n    \"\"\"Trains the model with InputPartition config.\"\"\"\n\n    def input_fn(params):\n      batch_size = params['batch_size']\n      x = np.random.normal(\n          size=[batch_size, image_height, image_width, 3]).astype(np.float32)\n      return dummy_input_fn_with_dataset(batch_size, repeat=True, x=x)\n\n    run_config = create_run_config(\n        iterations_per_loop=iterations_per_loop,\n        num_shards=num_shards,\n        num_cores_per_replica=num_cores_per_replica,\n        input_partition_dims=input_partition_dims,\n        per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=get_model_fn(),\n        config=run_config,\n       
 train_batch_size=128 * num_shards,\n        eval_batch_size=128 * num_shards)\n\n    est.train(input_fn, steps=steps, max_steps=None)\n\n  def test_train_with_non_positive_dims(self):\n    with self.assertRaisesRegex(ValueError,\n                                'All input partition dims must be >= 1.'):\n      self._train(\n          iterations_per_loop=2,\n          image_height=321,\n          image_width=224,\n          steps=2,\n          num_shards=1,\n          num_cores_per_replica=2,\n          input_partition_dims=[{\n              'x': [1, 2, 0, 1]\n          }, None])\n\n  def test_train_with_unmatched_partition_dims(self):\n    with self.assertRaisesRegex(\n        ValueError, 'The product of each input partition dim should '\n        'equal to num_cores_per_replica.*'):\n      self._train(\n          iterations_per_loop=2,\n          image_height=320,\n          image_width=224,\n          steps=2,\n          num_shards=1,\n          num_cores_per_replica=2,\n          input_partition_dims=[{\n              'x': [1, 2, 2, 1]\n          }, None])\n\n  def test_train_with_shape_unmatched_partition_dims(self):\n    with self.assertRaisesRegex(ValueError,\n                                'Input partition dims must have the same .*'):\n      self._train(\n          iterations_per_loop=2,\n          image_height=320,\n          image_width=224,\n          steps=2,\n          num_shards=1,\n          num_cores_per_replica=2,\n          input_partition_dims=[{\n              'x': [1, 2, 1]\n          }, None])\n\n  def test_train_with_unmatched_feature_keys(self):\n    with self.assertRaisesRegex(\n        ValueError, r'TPUConfig.input_partition_dims\\[0\\]'\n        ' mismatched feature .*'):\n      self._train(\n          iterations_per_loop=2,\n          image_height=320,\n          image_width=224,\n          steps=2,\n          num_shards=1,\n          num_cores_per_replica=2,\n          input_partition_dims=[{\n              'wrong_key': [1, 2, 1]\n     
     }, None])\n\n  def test_train_with_unmatched_label_keys(self):\n    with self.assertRaisesRegex(\n        ValueError, r'TPUConfig.input_partition_dims\\[1\\]'\n        ' mismatched label .*'):\n      self._train(\n          iterations_per_loop=2,\n          image_height=320,\n          image_width=224,\n          steps=2,\n          num_shards=1,\n          num_cores_per_replica=2,\n          input_partition_dims=[{\n              'x': [1, 2, 1, 1]\n          }, {\n              'wrong_key': None\n          }])\n\n  def test_train_uneven_partitions_successful(self):\n    # image_height=321, partitioned to 2 tensors with heights 161 and 160.\n    self._train(\n        iterations_per_loop=2,\n        image_height=321,\n        image_width=224,\n        steps=2,\n        num_shards=1,\n        num_cores_per_replica=2,\n        input_partition_dims=[{\n            'x': [1, 2, 1, 1]\n        }, None])\n\n  def test_uneven_partitions_computation(self):\n    image_height, image_width = 321, 224\n\n    def _predict_input_fn(params):\n      batch_size = params['batch_size']\n      x = np.random.normal(\n          size=[batch_size, image_height, image_width, 3]).astype(np.float32)\n      return dummy_input_fn_with_dataset(batch_size, repeat=False, x=x)\n\n    def _model_fn(features, labels, mode, params):\n      del params, labels\n      if mode == _PREDICT:\n        conv_output = layers.conv2d(features['x'], filters=1, kernel_size=3)\n        return tpu_estimator.TPUEstimatorSpec(\n            mode=mode, predictions={'predictions': conv_output})\n\n    run_config = create_run_config(\n        iterations_per_loop=2,\n        num_shards=1,\n        num_cores_per_replica=2,\n        input_partition_dims=[{\n            'x': [1, 2, 1, 1]\n        }],\n        per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=128,\n        
predict_batch_size=1)\n    res = list(est.predict(_predict_input_fn))\n    self.assertEqual(len(res), 1)\n    self.assertEqual(res[0]['predictions'].shape, (319, 222, 1))\n\n  def _test_input_partitions_with_nested_label(self, input_partition_dims):\n    image_height, image_width = 224, 224\n\n    def _dummy_input_fn_with_dataset(dataset_size,\n                                     repeat=True,\n                                     x=None,\n                                     batch_size=None):\n      if batch_size is None:\n        batch_size = dataset_size\n      if x is None:\n        x = np.random.normal(size=[dataset_size, 1]).astype(np.float32)\n      labels = [[2.0]] * dataset_size\n\n      dataset1 = dataset_lib.Dataset.from_tensor_slices(x)\n      dataset2 = dataset_lib.Dataset.from_tensor_slices(labels)\n      dataset = dataset_lib.Dataset.zip((dataset1, dataset2))\n      if repeat:\n        dataset = dataset.repeat()\n      dataset = dataset.batch(batch_size, drop_remainder=True)\n\n      def _map(x, y):\n        return {'x': x}, {'label_1': {'label_2': y, 'label_3': y}, 'label_4': y}\n\n      return dataset.map(_map)\n\n    def _input_fn(params):\n      batch_size = params['batch_size']\n      x = np.random.normal(\n          size=[batch_size, image_height, image_width, 3]).astype(np.float32)\n      return _dummy_input_fn_with_dataset(batch_size, repeat=True, x=x)\n\n    def _model_fn(features, labels, mode, params):\n      del params\n      predictions = dense_computation(features)\n      loss = losses.mean_squared_error(labels['label_1']['label_3'],\n                                       predictions)\n      optimizer = tf.compat.v1.tpu.CrossShardOptimizer(\n          training.GradientDescentOptimizer(learning_rate=0.5))\n      train_op = optimizer.minimize(loss, training.get_global_step())\n\n      return tpu_estimator.TPUEstimatorSpec(mode, loss=loss, train_op=train_op)\n\n    run_config = create_run_config(\n        iterations_per_loop=2,\n        
num_shards=1,\n        num_cores_per_replica=2,\n        input_partition_dims=input_partition_dims,\n        per_host_input_for_training=tpu_config.InputPipelineConfig.PER_HOST_V2)\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        config=run_config,\n        train_batch_size=128,\n        predict_batch_size=1)\n    est.train(_input_fn, steps=4, max_steps=None)\n\n  def test_fully_specified_input_partitions_with_with_nested_label(self):\n    self._test_input_partitions_with_nested_label([{'x': [1, 2, 1, 1]}, None])\n\n  def test_partial_specified_input_partitions_with_nested_label(self):\n    self._test_input_partitions_with_nested_label([{\n        'x': [1, 2, 1, 1]\n    }, {\n        'label_1': {\n            'label_2': None,\n            'label_3': None\n        },\n        'label_4': None\n    }])\n\n  def test_incorrect_input_partitions_with_nested_label(self):\n    with self.assertRaisesRegex(\n        ValueError, r'TPUConfig.input_partition_dims\\[1\\]'\n        ' mismatched the structure of labels. 
.*'):\n      self._test_input_partitions_with_nested_label([{\n          'x': [1, 2, 1, 1]\n      }, {\n          'label_1': None,\n          'label_4': None\n      }])\n\n\nclass TPUEstimatorInputPipelinePlacementTest(test.TestCase):\n\n  def _test_placement(self, per_host):\n\n    num_cores = 32\n    batch_sizes = []\n    global_batch_size = 1024\n    host_id_matcher = re.compile(r'^input_pipeline_task(\\d+)/(.*)$')\n    host_to_device = collections.defaultdict(list)\n\n    def _input_fn(params):\n      batch_sizes.append(params['batch_size'])\n      return dummy_input_fn(params['batch_size'])\n\n    def _model_fn(features, labels, mode, params):\n      # Examine the input pipeline placement.\n      operations = ops.get_default_graph().get_operations()\n      for op in operations:\n        result = host_id_matcher.match(op.name)\n        if result is None:\n          continue\n        # There is one op to read iterations_per_loop var (the Send node of\n        # tf.identity). It is colocated with global step. So, ignore here.\n        if result.group(2) == 'Identity/ReadVariableOp':\n          continue\n        host_id = int(result.group(1))\n        host_to_device[host_id].append(op.device)\n      return get_model_fn()(features, labels, mode, params)\n\n    run_config = tpu_config.RunConfig(\n        master='fake://123',\n        tpu_config=tpu_config.TPUConfig(\n            num_shards=num_cores, per_host_input_for_training=per_host))\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        train_batch_size=global_batch_size,\n        config=run_config)\n\n    old_value = tpu_estimator._WRAP_INPUT_FN_INTO_WHILE_LOOP\n    tpu_estimator._WRAP_INPUT_FN_INTO_WHILE_LOOP = True\n\n    try:\n      est.train(input_fn=_input_fn, steps=1)\n      self.fail('The train should not finish.')\n    except errors.NotFoundError:\n      # Expected. 
The TF sesion master is not valid.\n      pass\n\n    tpu_estimator._WRAP_INPUT_FN_INTO_WHILE_LOOP = old_value\n\n    expected_num_hosts = num_cores // 8\n    if per_host:\n      self.assertEqual(len(batch_sizes), expected_num_hosts)\n      self.assertEqual(batch_sizes[0], global_batch_size // expected_num_hosts)\n    else:\n      self.assertEqual(len(batch_sizes), num_cores)\n      self.assertEqual(batch_sizes[0], global_batch_size // num_cores)\n    self.assertEqual(expected_num_hosts, len(list(host_to_device.keys())))\n    for host_id in range(expected_num_hosts):\n      # On each host, all ops should be placed on the same device\n      device_set = set(host_to_device[host_id])\n      self.assertEqual(1, len(device_set))\n      self.assertEqual('/job:tpu_worker/task:{}/device:CPU:0'.format(host_id),\n                       host_to_device[host_id][0])\n\n  def _query_system(self, master_address, cluster_def, query_topology):\n    del master_address, cluster_def, query_topology\n    return tpu_system_metadata_lib.TPUSystemMetadata(\n        num_cores=32,\n        num_hosts=4,\n        num_of_cores_per_host=8,\n        topology=None,\n        devices=[])\n\n  def test_per_host_placement(self):\n    with test.mock.patch.object(\n        tpu_system_metadata_lib,\n        '_query_tpu_system_metadata',\n        side_effect=self._query_system):\n      self._test_placement(True)\n\n  def test_per_core_placement(self):\n    with test.mock.patch.object(\n        tpu_system_metadata_lib,\n        '_query_tpu_system_metadata',\n        side_effect=self._query_system):\n      self._test_placement(False)\n\n\nclass TPUEstimatorScaffoldTest(test.TestCase):\n\n  def _get_scaffold_fn(self, mode):\n\n    def _scaffold_fn_on_cpu():\n      scaffold = training.Scaffold()\n      finalize_fn = scaffold.finalize\n\n      def _finalize():\n        self.assertNotIn(mode, self.is_finalize_fn_called)\n        self.is_finalize_fn_called[mode] = True\n        return finalize_fn()\n\n      
scaffold.finalize = _finalize\n      return scaffold\n\n    return _scaffold_fn_on_cpu\n\n  def _input_fn(self, params):\n    return dummy_input_fn(params['batch_size'])\n\n  def _predict_input_fn(self, params):\n    return dummy_input_fn_with_dataset(\n        dataset_size=params['batch_size'], repeat=False)\n\n  def _model_fn(self, features, labels, mode, config, params):\n    \"\"\"Creates a head returning `TPUEstimatorSpec` based on mode.\"\"\"\n    predictions = layers.dense(\n        features['x'], 1, kernel_initializer=init_ops.zeros_initializer())\n\n    eval_metrics = None\n    train_op = None\n    loss = None\n\n    if mode != _PREDICT:\n      loss = losses.mean_squared_error(labels, predictions)\n      if mode == _TRAIN:\n        optimizer = training.GradientDescentOptimizer(learning_rate=0.5)\n        if params['use_tpu']:\n          optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)\n        train_op = optimizer.minimize(\n            loss, global_step=training.get_global_step())\n      elif mode == _EVAL:\n\n        def _metric_fn_on_cpu(labels, predictions):\n          return {\n              'mse': metrics_lib.mean_absolute_error(labels, predictions),\n          }\n\n        eval_metrics = (_metric_fn_on_cpu, [labels, predictions])\n\n    return tpu_estimator.TPUEstimatorSpec(\n        mode=mode,\n        train_op=train_op,\n        loss=loss,\n        predictions={'x': predictions},\n        scaffold_fn=self._get_scaffold_fn(mode),\n        eval_metrics=eval_metrics)\n\n  def test_train(self):\n    for use_tpu in [True, False]:\n      self.is_finalize_fn_called = {}\n      est = tpu_estimator.TPUEstimator(\n          model_fn=self._model_fn,\n          train_batch_size=8,\n          config=create_run_config(iterations_per_loop=4),\n          use_tpu=use_tpu)\n      est.train(input_fn=self._input_fn, steps=1)\n      self.assertTrue(self.is_finalize_fn_called[_TRAIN])\n\n  def test_eval(self):\n    for use_tpu in [True, False]:\n      
self.is_finalize_fn_called = {}\n      est = tpu_estimator.TPUEstimator(\n          model_fn=self._model_fn,\n          train_batch_size=8,\n          eval_batch_size=8,\n          config=create_run_config(iterations_per_loop=4),\n          use_tpu=use_tpu)\n\n      # Generate checkpoint.\n      est.train(input_fn=self._input_fn, steps=1)\n\n      est.evaluate(input_fn=self._input_fn, steps=1)\n      self.assertTrue(self.is_finalize_fn_called[_EVAL])\n\n  def test_predict(self):\n    for use_tpu in [True, False]:\n      self.is_finalize_fn_called = {}\n      est = tpu_estimator.TPUEstimator(\n          model_fn=self._model_fn,\n          train_batch_size=8,\n          predict_batch_size=8,\n          config=create_run_config(iterations_per_loop=4),\n          use_tpu=use_tpu)\n\n      # Generate checkpoint.\n      est.train(input_fn=self._input_fn, steps=1)\n      list(est.predict(input_fn=self._predict_input_fn))\n      self.assertTrue(self.is_finalize_fn_called[_PREDICT])\n\n  def test_scaffold_fn_capture_tpu_tensor(self):\n\n    def _model_fn(features, labels, mode, config, params):\n      \"\"\"Creates a head returning `TPUEstimatorSpec` based on mode.\"\"\"\n      del config, params\n      predictions = layers.dense(\n          features['x'], 1, kernel_initializer=init_ops.zeros_initializer())\n      loss = losses.mean_squared_error(labels, predictions)\n      optimizer = training.GradientDescentOptimizer(learning_rate=0.5)\n      optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)\n      train_op = optimizer.minimize(\n          loss, global_step=training.get_global_step())\n\n      def scaffold_fn():\n        summary_lib.scalar('loss_', loss)\n\n        return training.Scaffold()\n\n      return tpu_estimator.TPUEstimatorSpec(\n          mode=mode, train_op=train_op, loss=loss, scaffold_fn=scaffold_fn)\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        train_batch_size=8,\n        
config=create_run_config(iterations_per_loop=4))\n\n    with self.assertRaises(ValueError):\n      est.train(input_fn=self._input_fn, steps=1)\n\n  def test_scaffold_capture_tpu_tensor(self):\n\n    def _model_fn(features, labels, mode, config, params):\n      \"\"\"Creates a head returning `TPUEstimatorSpec` based on mode.\"\"\"\n      del config, params\n      predictions = layers.dense(\n          features['x'], 1, kernel_initializer=init_ops.zeros_initializer())\n      loss = losses.mean_squared_error(labels, predictions)\n      optimizer = training.GradientDescentOptimizer(learning_rate=0.5)\n      optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)\n      train_op = optimizer.minimize(\n          loss, global_step=training.get_global_step())\n\n      # Scaffold.finalize will \"merge\" all summaries, so we will be able\n      # to detect invalid TPU tensor capture.\n      summary_lib.scalar('loss_', loss)\n\n      def scaffold_fn():\n        return training.Scaffold()\n\n      return tpu_estimator.TPUEstimatorSpec(\n          mode=mode, train_op=train_op, loss=loss, scaffold_fn=scaffold_fn)\n\n    est = tpu_estimator.TPUEstimator(\n        model_fn=_model_fn,\n        train_batch_size=8,\n        config=create_run_config(iterations_per_loop=4))\n\n    with self.assertRaises(ValueError):\n      est.train(input_fn=self._input_fn, steps=1)\n\n\nclass TPUEstimatorScaffoldWithEMATest(test.TestCase):\n\n  def _get_scaffold(self, ema):\n    var_dict = ema.variables_to_restore()\n    return training.Scaffold(saver=training.Saver(var_dict))\n\n  def _input_fn(self, params):\n    return dummy_input_fn(params['batch_size'])\n\n  def _model_fn(self, features, labels, mode, config, params):\n    \"\"\"Creates a head returning `TPUEstimatorSpec` based on mode.\"\"\"\n    with variable_scope.variable_scope('foo'):\n      predictions = layers.dense(\n          features['x'], 1, kernel_initializer=init_ops.zeros_initializer())\n\n    eval_metrics = None\n    train_op = 
None\n\n    loss = losses.mean_squared_error(labels, predictions)\n    ema = moving_averages.ExponentialMovingAverage(decay=0.999)\n\n    if mode == _TRAIN:\n      optimizer = training.GradientDescentOptimizer(learning_rate=0.5)\n      optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)\n\n      opt_op = optimizer.minimize(loss, global_step=training.get_global_step())\n\n      with ops.control_dependencies([opt_op]):\n        train_op = ema.apply()\n\n    elif mode == _EVAL:\n\n      def _metric_fn_on_cpu(labels, predictions):\n        return {\n            'mse': metrics_lib.mean_absolute_error(labels, predictions),\n        }\n\n      eval_metrics = (_metric_fn_on_cpu, [labels, predictions])\n\n    # Change the saver for non-training mode.\n    scaffold_fn = None if mode == _TRAIN else (lambda: self._get_scaffold(ema))\n\n    return tpu_estimator.TPUEstimatorSpec(\n        mode=mode,\n        train_op=train_op,\n        loss=loss,\n        predictions=predictions,\n        scaffold_fn=scaffold_fn,\n        eval_metrics=eval_metrics)\n\n  def test_ema_with_train_and_evaluate(self):\n    use_tpu = True\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._model_fn,\n        train_batch_size=8,\n        eval_batch_size=8,\n        config=create_run_config(iterations_per_loop=1),\n        use_tpu=use_tpu)\n\n    # With iterations_per_loop=1 and train steps = 2, the after_run in the hook\n    # will be invoked once to change the bias value. 
Make the bias variable\n    # super large here to avoid flaky.\n    rewrite_var_hook = _RewriteVarHook(\n        scope_name='foo', variable_name='dense/bias', value=[100])\n    est.train(input_fn=self._input_fn, steps=2, hooks=[rewrite_var_hook])\n\n    bias_value = est.get_variable_value('foo/dense/bias')\n    bias_ma_value = est.get_variable_value(\n        'foo/dense/bias/ExponentialMovingAverage')\n\n    self.assertNotAllClose(bias_value, bias_ma_value)\n\n    model_variable_value_hook = (\n        _ModelVariableValueHook(scope_name='foo', variable_name='dense/bias'))\n\n    est.evaluate(\n        input_fn=self._input_fn, steps=1, hooks=[model_variable_value_hook])\n\n    bias_value_during_eval = model_variable_value_hook.got_value\n    self.assertAlmostEqual(bias_ma_value, bias_value_during_eval)\n\n\nclass _ModelVariableValueHook(session_run_hook.SessionRunHook):\n  \"\"\"Capture the value of given variable after initialization.\"\"\"\n\n  def __init__(self, scope_name, variable_name):\n    \"\"\"Constructs the run hook.\"\"\"\n    self.scope_name = scope_name\n    self.variable_name = variable_name\n    self.got_value = None\n\n  def after_create_session(self, sess, coord):\n    del coord\n\n    with variable_scope.variable_scope(self.scope_name, reuse=True):\n      self.got_value = sess.run(variable_scope.get_variable(self.variable_name))\n\n\nclass _RewriteVarHook(session_run_hook.SessionRunHook):\n  \"\"\"Rwrite the variable value hook.\"\"\"\n\n  def __init__(self, scope_name, variable_name, value):\n    \"\"\"Constructs the run hook.\"\"\"\n    self.scope_name = scope_name\n    self.variable_name = variable_name\n    self.value = value\n\n  def begin(self):\n    with variable_scope.variable_scope(self.scope_name, reuse=True):\n      self._var = variable_scope.get_variable(self.variable_name)\n\n  def after_run(self, run_context, run_values):\n    self._var.load(self.value, session=run_context.session)\n\n\nclass 
TPUEstimatorHostCallTest(test.TestCase):\n\n  def _input_fn(self, params):\n    return dummy_input_fn(params['batch_size'])\n\n  def _host_call(self, model_dir, mode):\n\n    def fn(global_step, labels, predictions):\n      global_step = math_ops.cast(global_step[0], dtypes.int64)\n      # We add a filename suffix here to avoid clashing with existing summary\n      # creation in Estimator. Otherwise both may attempt to open the same\n      # filename.\n      #\n      # The name of the op is set to model_dir to avoid ResourceManager caching\n      # the same summary writer instance across tests.\n      #\n      # In addition, we give different suffixes for train and eval to avoid\n      # FileWriter in evaluate() overwrites the events dumped by training.\n      # This is because the event file path has timestamps at second accuracy\n      # but the CPU training could be super fast.\n      with tf.summary.create_file_writer(\n          model_dir,\n          filename_suffix='.TPUEstimator-{}'.format(1 if mode == model_fn_lib\n                                                    .ModeKeys.TRAIN else 2),\n          name=os.path.basename(model_dir)).as_default():\n        with summary_ops_v2.record_summaries_every_n_global_steps(\n            5 if mode == model_fn_lib.ModeKeys.TRAIN else 1,\n            global_step=global_step):\n          loss = losses.mean_squared_error(labels, predictions)\n          summary_ops_v2.scalar('host_call_test', loss, step=global_step)\n          summary_ops_v2.scalar(\n              'host_call_global_step', global_step, step=global_step)\n          return tf.compat.v1.summary.all_v2_summary_ops()\n\n    return fn\n\n  def _metric_fn_on_cpu(self, labels, predictions):\n    return {\n        'mse': metrics_lib.mean_absolute_error(labels, predictions),\n    }\n\n  def _model_fn(self, model_dir):\n\n    def fn(features, labels, mode, params):\n      del params\n      train_op = None\n      predictions = dense_computation(features)\n      loss = 
losses.mean_squared_error(labels, predictions)\n      if mode == _TRAIN:\n        optimizer = tf.compat.v1.tpu.CrossShardOptimizer(\n            training.GradientDescentOptimizer(learning_rate=0.5))\n        train_op = optimizer.minimize(loss, training.get_global_step())\n      return tpu_estimator.TPUEstimatorSpec(\n          mode,\n          loss=loss,\n          train_op=train_op,\n          predictions=predictions,\n          eval_metrics=(self._metric_fn_on_cpu, [labels, predictions]),\n          host_call=(self._host_call(model_dir, mode), [\n              array_ops.reshape(\n                  math_ops.cast(training.get_global_step(), dtypes.int32), [1]),\n              labels, predictions\n          ]))\n\n    return fn\n\n  def _events_from_logdir(self, logdir):\n    files = gfile.ListDirectory(logdir)\n    events = []\n    found = False\n    for f in sorted(files):\n      # Note that we need to distinguish between the TPUEstimator events file\n      # and the SummarySaverHook one.\n      if '.tfevents.' 
in f and '.TPUEstimator' in f:\n        found = True\n        f = os.path.join(logdir, f)\n        events.extend(events_from_file(f))\n    self.assertEqual(True, found)\n    return events\n\n  def _test_summaries(self, use_tpu, output_every_n_steps=False):\n    outfeed_every_n_steps = 2 if output_every_n_steps else 1\n    model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())\n    run_config = tpu_config.RunConfig(\n        master='',\n        model_dir=model_dir,\n        tpu_config=tpu_config.TPUConfig(\n            iterations_per_loop=21,\n            num_shards=FLAGS.test_num_shards,\n            experimental_host_call_every_n_steps=outfeed_every_n_steps,\n        ))\n    est = tpu_estimator.TPUEstimator(\n        model_fn=self._model_fn(model_dir),\n        train_batch_size=8,\n        eval_batch_size=8,\n        config=run_config,\n        use_tpu=use_tpu)\n\n    est.train(input_fn=self._input_fn, steps=42)\n    events = self._events_from_logdir(model_dir)\n    events = [e for e in events if e.WhichOneof('what') != 'file_version']\n    if not output_every_n_steps or not use_tpu:\n      self.assertEqual(18, len(events))\n      self.assertEqual(\n          9,\n          len([e for e in events\n               if e.summary.value[0].tag == 'host_call_test']))\n      self.assertEqual([value*5 for value in range(9)], [\n          e.summary.value[0].simple_value\n          for e in events\n          if e.summary.value[0].tag == 'host_call_global_step'])\n    else:\n      self.assertEqual(10, len(events))\n      self.assertEqual(\n          5,\n          len([e for e in events\n               if e.summary.value[0].tag == 'host_call_test']))\n      self.assertEqual([0, 10, 20, 25, 35], [\n          e.summary.value[0].simple_value\n          for e in events\n          if e.summary.value[0].tag == 'host_call_global_step'])\n\n    est.evaluate(input_fn=self._input_fn, steps=7)\n    events = self._events_from_logdir(model_dir)\n    events = [e for e in events if 
e.WhichOneof('what') != 'file_version']\n    if not output_every_n_steps or not use_tpu:\n      self.assertEqual(32, len(events))  # 18 from train + 14 from eval\n      self.assertEqual(\n          16,  # 9 from train + 7 from eval\n          len([e for e in events\n               if e.summary.value[0].tag == 'host_call_test']))\n      self.assertEqual(\n          [value*5 for value in range(9)] + [42] * 7,\n          [e.summary.value[0].simple_value\n           for e in events\n           if e.summary.value[0].tag == 'host_call_global_step'])\n    else:\n      self.assertEqual(24, len(events))\n      self.assertEqual(\n          12,\n          len([e for e in events\n               if e.summary.value[0].tag == 'host_call_test']))\n      self.assertEqual(\n          [0, 10, 20, 25, 35] + [42] * 7,\n          [e.summary.value[0].simple_value\n           for e in events\n           if e.summary.value[0].tag == 'host_call_global_step'])\n\n  def test_summaries(self):\n    self._test_summaries(True)\n\n  def test_summaries_on_cpu(self):\n    self._test_summaries(False)\n\n  def test_summaries_every_n_steps(self):\n    self._test_summaries(True, True)\n\n  def test_summaries_on_cpu_every_n_steps(self):\n    self._test_summaries(False, True)\n\n  def test_keras_tensorflow_op_layer(self):\n    def model_fn(features, labels, mode, params):\n      del features, labels, params\n      i1 = tf_keras.Input(10)\n      i2 = tf_keras.Input(10)\n      out = tf.concat([i1, i2], axis=1)\n      out = tf_keras.layers.Dense(1)(out)\n      model = tf_keras.Model([i1, i2], out)\n      x = [tf.ones((5, 10)), tf.ones((5, 10))]\n      y = model(x)\n      loss = tf.reduce_mean(y)\n      if mode == _TRAIN:\n        optimizer = tf.compat.v1.tpu.CrossShardOptimizer(\n            training.GradientDescentOptimizer(learning_rate=0.5))\n        train_op = optimizer.minimize(loss, training.get_global_step())\n      return tpu_estimator.TPUEstimatorSpec(\n          mode,\n          loss=loss,\n        
  train_op=train_op)\n\n    model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())\n    run_config = tpu_config.RunConfig(\n        master='',\n        model_dir=model_dir,\n        tpu_config=tpu_config.TPUConfig(\n            iterations_per_loop=1,\n            num_shards=FLAGS.test_num_shards,\n        ))\n    est = tpu_estimator.TPUEstimator(\n        model_fn=model_fn,\n        train_batch_size=8,\n        eval_batch_size=8,\n        config=run_config)\n    est.train(input_fn=self._input_fn, steps=42)\n\n\nif __name__ == '__main__':\n  tf.compat.v1.disable_v2_behavior()\n  test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/tpu/util.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ===================================================================\n\"\"\"Utilities for the functionalities.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport re\nimport time\nimport numpy as np\nimport six\nimport tensorflow as tf\n\n_ITERATIONS_PER_LOOP_VALUE_REGEX = re.compile(\n    r'^(?P<value>[1-9]\\d*)((?P<suffix>[s|m|h])$|$)')\n\nIterationsPerLoopCounter = collections.namedtuple('IterationsPerLoopCounter',\n                                                  ['value', 'unit'])\n\n\ndef check_positive_integer(value, name):\n  \"\"\"Checks whether `value` is a positive integer.\"\"\"\n  if not isinstance(value, (six.integer_types, np.integer)):\n    raise TypeError('{} must be int, got {}'.format(name, type(value)))\n\n  if value <= 0:\n    raise ValueError('{} must be positive, got {}'.format(name, value))\n\n\ndef parse_iterations_per_loop(iterations_per_loop):\n  \"\"\"Parses the `iterations_per_loop` value.\n\n  The parser expects the value of the `iterations_per_loop` value to be a\n  positive integer value with unit:`count` or time-based value `<N><s|m|h>`\n  where <N> is any positive integer and `s`, `m`, `h` are unit of time in\n  seconds, minutes, hours respectively. 
Examples of valid values: `3600s`, `60m`\n  , `1h`.\n\n  Args:\n    iterations_per_loop: Number of iterations or time alloted to spend on per\n      device loop.\n\n  Returns:\n    A dictionary of `value` and `unit`. The `unit` value can be either a raw\n    `count`, or time in `seconds`.\n    {\n      \"value\": <positive-integer>,\n      \"unit\": <unit: `count` | `seconds`>\n    }\n  \"\"\"\n  m = _ITERATIONS_PER_LOOP_VALUE_REGEX.match(str(iterations_per_loop))\n  if m is None:\n    raise ValueError(\n        'Invalid TPUConfig `iterations_per_loop` value. Value must be positive '\n        'integer value or time-based value `<N><s|m|h>` where <N> is any'\n        'positive integer and `s`, `m`, `h` are unit of time in seconds, '\n        'minutes, hours respectively. Examples of valid values: `3600s`, `60m`,'\n        ' `1h`.')\n  unit_value = 'seconds' if m.group('suffix') in ['h', 'm', 's'] else 'count'\n  value = int(m.group('value'))\n  if m.group('suffix') == 'm':\n    value *= 60\n  elif m.group('suffix') == 'h':\n    value *= 3600\n  return IterationsPerLoopCounter(value, unit_value)\n\n\n# TODO(b/118302029) Remove this copy of MultiHostDatasetInitializerHook after we\n# release a tensorflow_estimator with MultiHostDatasetInitializerHook in\n# python/estimator/util.py.\nclass MultiHostDatasetInitializerHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Creates a SessionRunHook that initializes all passed iterators.\"\"\"\n\n  def __init__(self, dataset_initializers):\n    self._initializers = dataset_initializers\n\n  def after_create_session(self, session, coord):\n    del coord\n    start = time.time()\n    session.run(self._initializers)\n    tf.compat.v1.logging.info('Initialized dataset iterators in %d seconds',\n                              time.time() - start)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/training.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Classes and functions related to train_and_evaluate.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport collections\nimport json\nimport os\nimport time\n\nimport six\nimport tensorflow as tf\nfrom tensorflow.python.distribute import estimator_training as distribute_coordinator_training\nfrom tensorflow.python.training import basic_session_run_hooks\nfrom tensorflow.python.training import server_lib\nfrom tensorflow_estimator.python.estimator import estimator as estimator_lib\nfrom tensorflow_estimator.python.estimator import exporter as exporter_lib\nfrom tensorflow_estimator.python.estimator import run_config as run_config_lib\nfrom tensorflow_estimator.python.estimator.estimator_export import estimator_export\n\n_MAX_DELAY_SECS = 60\n_DELAY_SECS_PER_WORKER = 5\n_TF_CONFIG_ENV = 'TF_CONFIG'\n_ENVIRONMENT_KEY = 'environment'\n_ENVIRONMENT_GOOGLE_VALUE = 'google'\n_TRAINER_JOBS = (run_config_lib.TaskType.CHIEF, run_config_lib.TaskType.MASTER,\n                 run_config_lib.TaskType.WORKER)\n\n\ndef _validate_input_fn(input_fn):\n  \"\"\"Validates the `input_fn`.\"\"\"\n  if not callable(input_fn):\n    raise TypeError('`input_fn` must be callable, given: 
{}'.format(input_fn))\n\n\ndef _validate_hooks(hooks):\n  \"\"\"Validates the `hooks`.\"\"\"\n  hooks = tuple(hooks or [])\n  for hook in hooks:\n    if not isinstance(hook, tf.compat.v1.train.SessionRunHook):\n      raise TypeError(\n          'All hooks must be `SessionRunHook` instances, given: {}'.format(\n              hook))\n  return hooks\n\n\ndef _validate_saving_listeners(saving_listeners):\n  \"\"\"Validates the `saving_listeners`.\"\"\"\n  saving_listeners = tuple(saving_listeners or [])\n  for saving_listener in saving_listeners:\n    if not isinstance(saving_listener,\n                      tf.compat.v1.train.CheckpointSaverListener):\n      raise TypeError(\n          'All saving_listeners must be `CheckpointSaverListener` instances, '\n          'given: {}'.format(saving_listener))\n  return saving_listeners\n\n\ndef _validate_exporters(exporters):\n  \"\"\"Validates `exporters` and returns them as a tuple.\"\"\"\n  if not exporters:\n    return ()\n\n  if isinstance(exporters, exporter_lib.Exporter):\n    exporters = [exporters]\n\n  unique_names = []  # `Exporter`s should have unique names.\n  try:\n    for exporter in exporters:\n      if not isinstance(exporter, exporter_lib.Exporter):\n        # Error message will be printed out by the outer try/except.\n        raise TypeError\n\n      if not exporter.name:\n        full_list_of_names = [e.name for e in exporters]\n        raise ValueError('An Exporter cannot have a name that is `None` or'\n                         ' empty. All exporter names:'\n                         ' {}'.format(full_list_of_names))\n\n      if not isinstance(exporter.name, six.string_types):\n        raise ValueError('An Exporter must have a string name. Given: '\n                         '{}'.format(type(exporter.name)))\n\n      if exporter.name in unique_names:\n        full_list_of_names = [e.name for e in exporters]\n        raise ValueError(\n            '`exporters` must have unique names. 
Such a name cannot be `None`.'\n            ' All exporter names: {}'.format(full_list_of_names))\n      unique_names.append(exporter.name)\n  except TypeError:\n    # Two possibilities:\n    # - `exporters` is neither `Exporter` nor iterable.  Python has\n    #   raised a `TypeError` when iterating over `exporters`.\n    # - an `exporter` was None or not of type `Exporter`, so we raised a\n    #   `TypeError`.\n    raise TypeError('`exporters` must be an Exporter,'\n                    ' an iterable of Exporter, or `None`,'\n                    ' found %s.' % exporters)\n\n  return tuple(exporters)\n\n\ndef _is_google_env():\n  \"\"\"Detects whether current environment is google.\"\"\"\n  tf_config = json.loads(os.environ.get(_TF_CONFIG_ENV) or '{}')\n  if not tf_config:\n    tf.compat.v1.logging.warn(\n        'TF_CONFIG should not be empty in distributed environment.')\n  return tf_config.get(_ENVIRONMENT_KEY) == _ENVIRONMENT_GOOGLE_VALUE\n\n\n@estimator_export('estimator.TrainSpec')\nclass TrainSpec(\n    collections.namedtuple(\n        'TrainSpec', ['input_fn', 'max_steps', 'hooks', 'saving_listeners'])):\n  \"\"\"Configuration for the \"train\" part for the `train_and_evaluate` call.\n\n  `TrainSpec` determines the input data for the training, as well as the\n  duration. Optional hooks run at various stages of training.\n\n  Usage:\n\n  >>> train_spec = tf.estimator.TrainSpec(\n  ...    input_fn=lambda: 1,\n  ...    max_steps=100,\n  ...    hooks=[_StopAtSecsHook(stop_after_secs=10)],\n  ...    
saving_listeners=[_NewCheckpointListenerForEvaluate(None, 20, None)])\n  >>> train_spec.saving_listeners[0]._eval_throttle_secs\n  20\n  >>> train_spec.hooks[0]._stop_after_secs\n  10\n  >>> train_spec.max_steps\n  100\n  \"\"\"\n\n  def __new__(cls, input_fn, max_steps=None, hooks=None, saving_listeners=None):\n    \"\"\"Creates a validated `TrainSpec` instance.\n\n    Args:\n      input_fn: A function that provides input data for training as minibatches.\n        See [Premade Estimators](\n        https://tensorflow.org/guide/premade_estimators#create_input_functions)\n          for more information. The function should construct and return one of\n        the following:\n          * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a\n            tuple (features, labels) with same constraints as below.\n          * A tuple (features, labels): Where features is a `Tensor` or a\n            dictionary of string feature name to `Tensor` and labels is a\n            `Tensor` or a dictionary of string label name to `Tensor`.\n      max_steps: Int. Positive number of total steps for which to train model.\n        If `None`, train forever. The training `input_fn` is not expected to\n        generate `OutOfRangeError` or `StopIteration` exceptions. 
See the\n        `train_and_evaluate` stop condition section for details.\n      hooks: Iterable of `tf.train.SessionRunHook` objects to run on all workers\n        (including chief) during training.\n      saving_listeners: Iterable of `tf.estimator.CheckpointSaverListener`\n        objects to run on chief during training.\n\n    Returns:\n      A validated `TrainSpec` object.\n\n    Raises:\n      ValueError: If any of the input arguments is invalid.\n      TypeError: If any of the arguments is not of the expected type.\n    \"\"\"\n    # Validate input_fn.\n    _validate_input_fn(input_fn)\n\n    # Validate max_steps.\n    if max_steps is not None and max_steps <= 0:\n      raise ValueError(\n          'Must specify max_steps > 0, given: {}'.format(max_steps))\n\n    # Validate hooks.\n    hooks = _validate_hooks(hooks)\n\n    # Validate saving_listeners.\n    saving_listeners = _validate_saving_listeners(saving_listeners)\n\n    return super(TrainSpec, cls).__new__(\n        cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks,\n        saving_listeners=saving_listeners)\n\n\n@estimator_export('estimator.EvalSpec')\nclass EvalSpec(\n    collections.namedtuple('EvalSpec', [\n        'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs',\n        'throttle_secs'\n    ])):\n  \"\"\"Configuration for the \"eval\" part for the `train_and_evaluate` call.\n\n  `EvalSpec` combines details of evaluation of the trained model as well as its\n  export. Evaluation consists of computing metrics to judge the performance of\n  the trained model.  
Export writes out the trained model on to external\n  storage.\n  \"\"\"\n\n  def __new__(cls,\n              input_fn,\n              steps=100,\n              name=None,\n              hooks=None,\n              exporters=None,\n              start_delay_secs=120,\n              throttle_secs=600):\n    \"\"\"Creates a validated `EvalSpec` instance.\n\n    Args:\n      input_fn: A function that constructs the input data for evaluation. See\n        [Premade Estimators](\n        https://tensorflow.org/guide/premade_estimators#create_input_functions)\n          for more information. The function should construct and return one of\n        the following:\n          * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a\n            tuple (features, labels) with same constraints as below.\n          * A tuple (features, labels): Where features is a `Tensor` or a\n            dictionary of string feature name to `Tensor` and labels is a\n            `Tensor` or a dictionary of string label name to `Tensor`.\n      steps: Int. Positive number of steps for which to evaluate model. If\n        `None`, evaluates until `input_fn` raises an end-of-input exception. See\n        `Estimator.evaluate` for details.\n      name: String. Name of the evaluation if user needs to run multiple\n        evaluations on different data sets. Metrics for different evaluations\n        are saved in separate folders, and appear separately in tensorboard.\n      hooks: Iterable of `tf.train.SessionRunHook` objects to run during\n        evaluation.\n      exporters: Iterable of `Exporter`s, or a single one, or `None`.\n        `exporters` will be invoked after each evaluation.\n      start_delay_secs: Int. Start evaluating after waiting for this many\n        seconds.\n      throttle_secs: Int. Do not re-evaluate unless the last evaluation was\n        started at least this many seconds ago. 
Of course, evaluation does not\n        occur if no new checkpoints are available, hence, this is the minimum.\n\n    Returns:\n      A validated `EvalSpec` object.\n\n    Raises:\n      ValueError: If any of the input arguments is invalid.\n      TypeError: If any of the arguments is not of the expected type.\n    \"\"\"\n    # Validate input_fn.\n    _validate_input_fn(input_fn)\n\n    # Validate steps.\n    if steps is not None and steps <= 0:\n      raise ValueError('Must specify steps > 0, given: {}'.format(steps))\n\n    # Validate name.\n    if name is not None and not isinstance(name, six.string_types):\n      raise TypeError('`name` must be string, given: {}'.format(name))\n\n    # Validate hooks.\n    hooks = _validate_hooks(hooks)\n\n    # Validate exporters.\n    exporters = _validate_exporters(exporters)\n\n    # Validate start_delay_secs.\n    if start_delay_secs < 0:\n      raise ValueError('Must specify start_delay_secs >= 0, given: {}'.format(\n          start_delay_secs))\n\n    # Validate throttle_secs.\n    if throttle_secs < 0:\n      raise ValueError(\n          'Must specify throttle_secs >= 0, given: {}'.format(throttle_secs))\n\n    return super(EvalSpec, cls).__new__(\n        cls,\n        input_fn=input_fn,\n        steps=steps,\n        name=name,\n        hooks=hooks,\n        exporters=exporters,\n        start_delay_secs=start_delay_secs,\n        throttle_secs=throttle_secs)\n\n\n@estimator_export('estimator.train_and_evaluate')\ndef train_and_evaluate(estimator, train_spec, eval_spec):\n  \"\"\"Train and evaluate the `estimator`.\n\n  This utility function trains, evaluates, and (optionally) exports the model by\n  using the given `estimator`. All training related specification is held in\n  `train_spec`, including training `input_fn` and training max steps, etc. 
All\n  evaluation and export related specification is held in `eval_spec`, including\n  evaluation `input_fn`, steps, etc.\n\n  This utility function provides consistent behavior for both local\n  (non-distributed) and distributed configurations. The default distribution\n  configuration is parameter server-based between-graph replication. For other\n  types of distribution configurations such as all-reduce training, please use\n  [DistributionStrategies](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/distribute).\n\n  Overfitting: In order to avoid overfitting, it is recommended to set up the\n  training `input_fn` to shuffle the training data properly.\n\n  Stop condition: In order to support both distributed and non-distributed\n  configuration reliably, the only supported stop condition for model\n  training is `train_spec.max_steps`. If `train_spec.max_steps` is `None`, the\n  model is trained forever. *Use with care* if model stop condition is\n  different. For example, assume that the model is expected to be trained with\n  one epoch of training data, and the training `input_fn` is configured to throw\n  `OutOfRangeError` after going through one epoch, which stops the\n  `Estimator.train`. For a three-training-worker distributed configuration, each\n  training worker is likely to go through the whole epoch independently. So, the\n  model will be trained with three epochs of training data instead of one epoch.\n\n  Example of local (non-distributed) training:\n\n  ```python\n  # Set up feature columns.\n  categorial_feature_a = categorial_column_with_hash_bucket(...)\n  categorial_feature_a_emb = embedding_column(\n      categorical_column=categorial_feature_a, ...)\n  ...  
# other feature columns\n\n  estimator = DNNClassifier(\n      feature_columns=[categorial_feature_a_emb, ...],\n      hidden_units=[1024, 512, 256])\n\n  # Or set up the model directory\n  #   estimator = DNNClassifier(\n  #       config=tf.estimator.RunConfig(\n  #           model_dir='/my_model', save_summary_steps=100),\n  #       feature_columns=[categorial_feature_a_emb, ...],\n  #       hidden_units=[1024, 512, 256])\n\n  # Input pipeline for train and evaluate.\n  def train_input_fn(): # returns x, y\n    # please shuffle the data.\n    pass\n  def eval_input_fn(): # returns x, y\n    pass\n\n  train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=1000)\n  eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)\n\n  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)\n  ```\n  Note that in current implementation `estimator.evaluate` will be called\n  multiple times. This means that evaluation graph (including eval_input_fn)\n  will be re-created for each `evaluate` call. `estimator.train` will be called\n  only once.\n\n  Example of distributed training:\n\n  Regarding the example of distributed training, the code above can be used\n  without a change (Please do make sure that the `RunConfig.model_dir` for all\n  workers is set to the same directory, i.e., a shared file system all workers\n  can read and write). The only extra work to do is setting the environment\n  variable `TF_CONFIG` properly for each worker correspondingly.\n\n  Also see\n  [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed).\n\n  Setting environment variable depends on the platform. 
For example, on Linux,\n  it can be done as follows (`$` is the shell prompt):\n\n  ```\n  $ TF_CONFIG='<replace_with_real_content>' python train_model.py\n  ```\n\n  For the content in `TF_CONFIG`, assume that the training cluster spec looks\n  like:\n\n  ```\n  cluster = {\"chief\": [\"host0:2222\"],\n             \"worker\": [\"host1:2222\", \"host2:2222\", \"host3:2222\"],\n             \"ps\": [\"host4:2222\", \"host5:2222\"]}\n  ```\n\n  Example of `TF_CONFIG` for chief training worker (must have one and only one):\n\n  ```\n  # This should be a JSON string, which is set as environment variable. Usually\n  # the cluster manager handles that.\n  TF_CONFIG='{\n      \"cluster\": {\n          \"chief\": [\"host0:2222\"],\n          \"worker\": [\"host1:2222\", \"host2:2222\", \"host3:2222\"],\n          \"ps\": [\"host4:2222\", \"host5:2222\"]\n      },\n      \"task\": {\"type\": \"chief\", \"index\": 0}\n  }'\n  ```\n  Note that the chief worker also does the model training job, similar to other\n  non-chief training workers (see next paragraph). In addition to the model\n  training, it manages some extra work, e.g., checkpoint saving and restoring,\n  writing summaries, etc.\n\n  Example of `TF_CONFIG` for non-chief training worker (optional, could be\n  multiple):\n\n  ```\n  # This should be a JSON string, which is set as environment variable. Usually\n  # the cluster manager handles that.\n  TF_CONFIG='{\n      \"cluster\": {\n          \"chief\": [\"host0:2222\"],\n          \"worker\": [\"host1:2222\", \"host2:2222\", \"host3:2222\"],\n          \"ps\": [\"host4:2222\", \"host5:2222\"]\n      },\n      \"task\": {\"type\": \"worker\", \"index\": 0}\n  }'\n  ```\n  where the `task.index` should be set as 0, 1, 2, in this example, respectively\n  for non-chief training workers.\n\n  Example of `TF_CONFIG` for parameter server, aka ps (could be multiple):\n\n  ```\n  # This should be a JSON string, which is set as environment variable. 
Usually\n  # the cluster manager handles that.\n  TF_CONFIG='{\n      \"cluster\": {\n          \"chief\": [\"host0:2222\"],\n          \"worker\": [\"host1:2222\", \"host2:2222\", \"host3:2222\"],\n          \"ps\": [\"host4:2222\", \"host5:2222\"]\n      },\n      \"task\": {\"type\": \"ps\", \"index\": 0}\n  }'\n  ```\n  where the `task.index` should be set as 0 and 1, in this example, respectively\n  for parameter servers.\n\n  Example of `TF_CONFIG` for evaluator task. Evaluator is a special task that is\n  not part of the training cluster. There could be only one. It is used for\n  model evaluation.\n\n  ```\n  # This should be a JSON string, which is set as environment variable. Usually\n  # the cluster manager handles that.\n  TF_CONFIG='{\n      \"cluster\": {\n          \"chief\": [\"host0:2222\"],\n          \"worker\": [\"host1:2222\", \"host2:2222\", \"host3:2222\"],\n          \"ps\": [\"host4:2222\", \"host5:2222\"]\n      },\n      \"task\": {\"type\": \"evaluator\", \"index\": 0}\n  }'\n  ```\n\n  When `distribute` or `experimental_distribute.train_distribute` and\n  `experimental_distribute.remote_cluster` is set, this method will start a\n  client running on the current host which connects to the `remote_cluster` for\n  training and evaluation.\n\n  Args:\n    estimator: An `Estimator` instance to train and evaluate.\n    train_spec: A `TrainSpec` instance to specify the training specification.\n    eval_spec: A `EvalSpec` instance to specify the evaluation and export\n      specification.\n\n  Returns:\n    A tuple of the result of the `evaluate` call to the `Estimator` and the\n    export results using the specified `Exporter`s.\n    Currently, the return value is undefined for distributed training mode.\n\n  Raises:\n    ValueError: if environment variable `TF_CONFIG` is incorrectly set.\n  \"\"\"\n  _assert_eval_spec(eval_spec)  # fail fast if eval_spec is invalid.\n  estimator_lib._estimator_api_gauge.get_cell('train_and_evaluate').set(True) 
 # pylint: disable=protected-access\n\n  executor = _TrainingExecutor(\n      estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)\n  config = estimator.config\n\n  # If `distribute_coordinator_mode` is set and running in distributed\n  # environment, we run `train_and_evaluate` via distribute coordinator.\n  if distribute_coordinator_training.should_run_distribute_coordinator(config):\n    tf.compat.v1.logging.info(\n        'Running `train_and_evaluate` with Distribute Coordinator.')\n    distribute_coordinator_training.train_and_evaluate(estimator, train_spec,\n                                                       eval_spec,\n                                                       _TrainingExecutor)\n    return\n\n  if (config.task_type == run_config_lib.TaskType.EVALUATOR and\n      config.task_id > 0):\n    raise ValueError(\n        'For distributed training, there can only be one `evaluator` task '\n        '(with task id 0).  Given task id {}'.format(config.task_id))\n\n  return executor.run()\n\n\nclass _StopAtSecsHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Stops given secs after begin is called.\"\"\"\n\n  def __init__(self, stop_after_secs):\n    self._stop_after_secs = stop_after_secs\n    self._start_time = None\n\n  def begin(self):\n    self._start_time = time.time()\n\n  def after_run(self, run_context, run_values):\n    del run_values\n    if time.time() - self._start_time >= self._stop_after_secs:\n      run_context.request_stop()\n\n\nclass _NewCheckpointListenerForEvaluate(\n    tf.compat.v1.train.CheckpointSaverListener):\n  \"\"\"A saver listener to run evaluate with every checkpoint.\"\"\"\n\n  def __init__(self, evaluator, eval_throttle_secs, continuous_eval_listener):\n    self._evaluator = evaluator\n    self._eval_throttle_secs = eval_throttle_secs\n    self._continuous_eval_listener = continuous_eval_listener\n    self.eval_result, self.export_results = None, None\n\n  def begin(self):\n    self._timer = 
basic_session_run_hooks.SecondOrStepTimer(\n        every_secs=self._eval_throttle_secs)\n    self._is_first_run = True\n\n  def after_save(self, session, global_step_value):\n    del session  # unused; required by signature.\n    # skip first run model is not trained yet.\n    if self._is_first_run:\n      self._is_first_run = False\n      return\n\n    if not self._continuous_eval_listener.before_eval():\n      tf.compat.v1.logging.info(\n          'Exiting training and evaluation loop, as requested by '\n          '_ContinuousEvalListener.before_eval.')\n      return True\n    if self._timer.should_trigger_for_step(global_step_value):\n      self._evaluate(global_step_value)  # updates self.eval_result\n      if not self._continuous_eval_listener.after_eval(self.eval_result):\n        tf.compat.v1.logging.info('Exiting evaluation, as requested by '\n                                  '_ContinuousEvalListener.after_eval.')\n        return True\n    else:\n      # TODO(ispir): add remaining time in the log.\n      tf.compat.v1.logging.info(\n          'Skip the current checkpoint eval due to throttle secs '\n          '({} secs).'.format(self._eval_throttle_secs))\n\n  def end(self, session, global_step_value):\n    # Evaluate if the last step has not been evaluated, yet.\n    if global_step_value != self._timer.last_triggered_step():\n      if self._continuous_eval_listener.before_eval():\n        self._evaluate(global_step_value)\n        self._continuous_eval_listener.after_eval(self.eval_result)\n\n  def _evaluate(self, global_step_value):\n    self._timer.update_last_triggered_step(global_step_value)\n    self.eval_result, self.export_results = (\n        self._evaluator.evaluate_and_export())\n    if self.eval_result.status != _EvalStatus.EVALUATED:\n      #  This is unexpected; should never happen.\n      #  Training should always end with a new checkpoint.\n      raise RuntimeError('There was no new checkpoint after the training. 
'\n                         'Eval status: {}'.format(self.eval_result.status))\n\n\nclass _TrainingExecutor(object):\n  \"\"\"The executor to run `Estimator` training and evaluation.\n\n  This implementation supports both distributed and non-distributed (aka local)\n  training and evaluation based on the setting in `tf.estimator.RunConfig`.\n  \"\"\"\n\n  def __init__(self,\n               estimator,\n               train_spec,\n               eval_spec,\n               train_hooks=None,\n               continuous_eval_listener=None):\n    if not isinstance(estimator,\n                      (estimator_lib.Estimator, estimator_lib.EstimatorV2)):\n      raise TypeError('`estimator` must have type `tf.estimator.Estimator`. '\n                      'Got: {}'.format(type(estimator)))\n    self._estimator = estimator\n\n    if not isinstance(train_spec, TrainSpec):\n      raise TypeError('`train_spec` must have type `tf.estimator.TrainSpec`. '\n                      'Got: {}'.format(type(train_spec)))\n    self._train_spec = train_spec\n\n    if eval_spec and not isinstance(eval_spec, EvalSpec):\n      raise TypeError('`eval_spec` must be either `None` or have type '\n                      '`tf.estimator.EvalSpec`. Got: {}'.format(\n                          type(eval_spec)))\n    self._eval_spec = eval_spec\n\n    self._train_hooks = _validate_hooks(train_hooks)\n\n    if (continuous_eval_listener and\n        not isinstance(continuous_eval_listener, _ContinuousEvalListener)):\n      raise TypeError('`continuous_eval_listener` must have type '\n                      '`_ContinuousEvalListener`.')\n    self._continuous_eval_listener = (\n        continuous_eval_listener or _ContinuousEvalListener())\n\n  @property\n  def estimator(self):\n    return self._estimator\n\n  def run(self):\n    \"\"\"Executes the run_foo for task type `foo`.\n\n    `_TrainingExecutor` predefines the procedure for task type 'chief',\n    'worker', 'ps', and 'evaluator'. 
For task type `foo`, the corresponding\n    procedure is `run_foo'. This `run` method invoke the procedure base on the\n    `RunConfig.task_type`.\n\n    Returns:\n      A tuple of the result of the `evaluate` call to the `Estimator` and the\n      export results using the specified `ExportStrategy`.\n      Currently undefined for distributed training mode.\n\n    Raises:\n      ValueError: if the estimator.config is mis-configured.\n    \"\"\"\n    config = self._estimator.config\n\n    if (not config.cluster_spec and\n        config.task_type != run_config_lib.TaskType.EVALUATOR):\n      tf.compat.v1.logging.info(\n          'Running training and evaluation locally (non-distributed).')\n      return self.run_local()\n\n    # Distributed case.\n    if not config.task_type:\n      # TODO(xiejw): Improve the error message about how to set the TF_CONFIG\n      # correctly.\n      raise ValueError(\n          '`estimator.config` must have task_type set. This usually means '\n          'TF_CONFIG environment is not set correctly.')\n\n    if config.task_type == 'local':\n      raise ValueError(\n          '`task.type` in TF_CONFIG cannot be `local`. Leaving `cluster` and '\n          '`task` properties in TF_CONFIG absent triggers train and evaluate '\n          '`Estimator` locally (non-distributed).')\n\n    # For task type foo, call executor.run_foo.\n    available_tasks = [\n        x for x in dir(self) if x.startswith('run_') and x != 'run_local' and\n        callable(getattr(self, x))\n    ]\n    task_to_run = 'run_' + config.task_type\n    if task_to_run not in available_tasks:\n      raise ValueError(\n          'Task type {} is not supported. 
Supported task types are {}'.format(\n              config.task_type, [x[len('run_'):] for x in available_tasks]))\n    getattr(self, task_to_run)()\n\n  def run_chief(self):\n    \"\"\"Runs task chief.\"\"\"\n    # TODO(xiejw): To allow execution framework to add train hooks.\n    return self._start_distributed_training(\n        saving_listeners=self._train_spec.saving_listeners)\n\n  def run_worker(self):\n    \"\"\"Runs task (training) worker.\"\"\"\n    # TODO(xiejw): To allow execution framework to add train hooks.\n    return self._start_distributed_training()\n\n  def run_master(self):\n    \"\"\"Runs task master.\"\"\"\n    _assert_eval_spec(self._eval_spec)\n\n    # Final export signal: For any eval result with global_step >= train\n    # max_steps, the evaluator will send the final export signal. There is a\n    # small chance that the Estimator.train stopping logic sees a different\n    # global_step value (due to global step race condition and the fact the\n    # saver sees a larger value for checkpoint saving), which does not end\n    # the training. When the training ends, a new checkpoint is generated, which\n    # triggers the listener again. 
So, it could be the case the final export is\n    # triggered twice.\n    #\n    # But here, throttle_secs will skip the next intermediate checkpoint and,\n    # so, the double final export chance is very small.\n    evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,\n                                             self._train_spec.max_steps)\n\n    # When the underlying `Estimator` object saves a new checkpoint, we would\n    # like this callback to be called so that evaluation and export can trigger.\n    saving_listeners = self._train_spec.saving_listeners + tuple(\n        [_NewCheckpointListenerForEvaluate(evaluator,\n                                           self._eval_spec.throttle_secs,\n                                           _ContinuousEvalListener())])\n    self._start_distributed_training(saving_listeners=saving_listeners)\n\n  def run_evaluator(self):\n    \"\"\"Runs task evaluator.\"\"\"\n    # TODO(xiejw): To allow execution framework to add continuous eval listener.\n    return self._start_continuous_evaluation()\n\n  def run_ps(self):\n    \"\"\"Runs task parameter server (in training cluster spec).\"\"\"\n    config = self._estimator.config\n    server = self._start_std_server(config)\n    server.join()\n\n  def run_local(self):\n    \"\"\"Runs training and evaluation locally (non-distributed).\"\"\"\n    _assert_eval_spec(self._eval_spec)\n\n    train_hooks = list(self._train_spec.hooks) + list(self._train_hooks)\n    tf.compat.v1.logging.info(\n        'Start train and evaluate loop. The evaluate will happen '\n        'after every checkpoint. 
Checkpoint frequency is determined '\n        'based on RunConfig arguments: save_checkpoints_steps {} or '\n        'save_checkpoints_secs {}.'.format(\n            self._estimator.config.save_checkpoints_steps,\n            self._estimator.config.save_checkpoints_secs))\n\n    evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,\n                                             self._train_spec.max_steps)\n\n    listener_for_eval = _NewCheckpointListenerForEvaluate(\n        evaluator, self._eval_spec.throttle_secs,\n        self._continuous_eval_listener)\n    saving_listeners = self._train_spec.saving_listeners + (listener_for_eval,)\n\n    self._estimator.train(\n        input_fn=self._train_spec.input_fn,\n        max_steps=self._train_spec.max_steps,\n        hooks=train_hooks,\n        saving_listeners=saving_listeners)\n\n    eval_result = listener_for_eval.eval_result or _EvalResult(\n        status=_EvalStatus.MISSING_CHECKPOINT)\n    return eval_result.metrics, listener_for_eval.export_results\n\n  def _start_std_server(self, config):\n    \"\"\"Creates, starts, and returns a server_lib.Server.\"\"\"\n    if (not config.cluster_spec or not config.task_type or\n        config.task_id is None):\n      raise RuntimeError('Could not start server; be sure to specify '\n                         'cluster_spec, task_type, and task in '\n                         'RunConfig or set the TF_CONFIG environment variable.')\n\n    if not config.master:\n      jobs = config.cluster_spec.jobs\n      if (len(jobs) == 1 and\n          len(config.cluster_spec.job_tasks(jobs[0])) == 1 and\n          config.task_type in _TRAINER_JOBS):\n        # For distributed training, config.master is empty if and only if it has\n        # a single node in the cluster spec. 
In this case, we should not start\n        # the server.\n        tf.compat.v1.logging.info(\n            'Skip starting Tensorflow server as there is only one '\n            'node in the cluster.')\n        return\n      else:\n        raise RuntimeError(\n            'Could not start server; be sure to specify master in '\n            'RunConfig or set the TF_CONFIG environment variable.')\n\n    tf.compat.v1.logging.info('Start Tensorflow server.')\n\n    if config.session_config is None:\n      session_config = tf.compat.v1.ConfigProto(log_device_placement=False)\n    else:\n      session_config = tf.compat.v1.ConfigProto(\n          log_device_placement=False,\n          gpu_options=config.session_config.gpu_options)\n\n    server = server_lib.Server(\n        config.cluster_spec,\n        job_name=config.task_type,\n        task_index=config.task_id,\n        config=session_config,\n        start=False,\n        protocol=config.protocol)\n    server.start()\n    return server\n\n  def _start_distributed_training(self, saving_listeners=None):\n    \"\"\"Calls `Estimator` train in a distributed setting.\"\"\"\n    config = self._estimator.config\n\n    # Start in-process TensorFlow server if needed. It's important to start the\n    # server before we (optionally) sleep. Otherwise, the servers will wait to\n    # connect to each other before starting to train.\n    if not _is_google_env():\n      self._start_std_server(config)\n\n    # Delay worker to start. For asynchronous training, this usually helps model\n    # to converge faster.  
Chief starts the training immediately, so, worker\n    # with task id x (0-based) should wait (x+1) * _DELAY_SECS_PER_WORKER.\n    start_delay_secs = 0\n    if config.task_type == run_config_lib.TaskType.WORKER:\n      # TODO(xiejw): Replace the hard code logic (task_id + 1) with unique id in\n      # training cluster.\n\n      max_delay_secs = _MAX_DELAY_SECS\n      if config.experimental_max_worker_delay_secs is not None:\n        max_delay_secs = int(config.experimental_max_worker_delay_secs)\n\n      start_delay_secs = min(max_delay_secs,\n                             (config.task_id + 1) * _DELAY_SECS_PER_WORKER)\n    if start_delay_secs > 0:\n      tf.compat.v1.logging.info('Waiting %d secs before starting training.',\n                                start_delay_secs)\n      time.sleep(start_delay_secs)\n\n    self._estimator.train(\n        input_fn=self._train_spec.input_fn,\n        max_steps=self._train_spec.max_steps,\n        hooks=list(self._train_spec.hooks) + list(self._train_hooks),\n        saving_listeners=saving_listeners)\n\n  def _start_continuous_evaluation(self):\n    \"\"\"Repeatedly calls `Estimator` evaluate and export until training ends.\"\"\"\n\n    _assert_eval_spec(self._eval_spec)\n\n    start_delay_secs = self._eval_spec.start_delay_secs\n    if start_delay_secs:\n      tf.compat.v1.logging.info('Waiting %f secs before starting eval.',\n                                start_delay_secs)\n      time.sleep(start_delay_secs)\n\n    latest_eval_result = None\n    evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,\n                                             self._train_spec.max_steps)\n\n    should_early_stop = False\n    while not should_early_stop:\n      if (latest_eval_result and\n          latest_eval_result.status == _EvalStatus.EVALUATED):\n        global_step = latest_eval_result.metrics.get(\n            tf.compat.v1.GraphKeys.GLOBAL_STEP)\n        if (global_step and self._train_spec.max_steps and\n     
       global_step >= self._train_spec.max_steps):\n          tf.compat.v1.logging.info(\n              'Exiting evaluation, global_step=%s >= train max_steps=%s',\n              global_step, self._train_spec.max_steps)\n          return\n\n      latest_eval_result, should_early_stop = self._execute_evaluator_once(\n          evaluator, self._continuous_eval_listener,\n          self._eval_spec.throttle_secs)\n\n  def _execute_evaluator_once(self, evaluator, continuous_eval_listener,\n                              throttle_secs):\n    \"\"\"Executes the `evaluator`.\"\"\"\n\n    _assert_eval_spec(self._eval_spec)\n\n    start = time.time()\n\n    eval_result = None\n    should_early_stop = False\n\n    if not continuous_eval_listener.before_eval():\n      tf.compat.v1.logging.info('Exiting evaluation, as requested by '\n                                '_ContinuousEvalListener.before_eval.')\n      should_early_stop = True\n      return (eval_result, should_early_stop)\n\n    # Final export signal: For any eval result with global_step >= train\n    # max_steps, the evaluator will send the final export signal. 
The next\n    # iteration of while loop will end the continuous eval as the stopping\n    # condition is satisfied (both checks use the same global_step value,\n    # i.e., no race condition)\n    eval_result, _ = evaluator.evaluate_and_export()\n\n    if not self._continuous_eval_listener.after_eval(eval_result):\n      tf.compat.v1.logging.info('Exiting evaluation, as requested by '\n                                '_ContinuousEvalListener.after_eval.')\n      should_early_stop = True\n      return (eval_result, should_early_stop)\n\n    # Throttle if necessary.\n    elapsed_time = time.time() - start\n    difference = throttle_secs - elapsed_time\n    if difference > 0:\n      tf.compat.v1.logging.info(\n          'Waiting %f secs before starting next eval run.', difference\n      )\n      time.sleep(difference)\n    elif (throttle_secs == 0 and eval_result.status != _EvalStatus.EVALUATED):\n      # Prints a user-actionable warning to avoid unnecessary load on evaluator.\n      tf.compat.v1.logging.warning(\n          'EvalSpec.throttle_secs is set as 0. This might overload the job '\n          'before finding (next) new checkpoint. 
Please consider to increase '\n          'it.')\n\n    return (eval_result, should_early_stop)\n\n  class _Evaluator(object):\n    \"\"\"A helper class to call `Estimator.evaluate` and export model.\"\"\"\n\n    def __init__(self, estimator, eval_spec, max_training_steps):\n      self._estimator = estimator\n\n      _assert_eval_spec(eval_spec)\n      self._eval_spec = eval_spec\n\n      self._is_final_export_triggered = False\n      self._previous_ckpt_path = None\n      self._last_warning_time = 0\n      self._max_training_steps = max_training_steps\n\n    @property\n    def is_final_export_triggered(self):\n      return self._is_final_export_triggered\n\n    def evaluate_and_export(self):\n      \"\"\"Evaluate and (maybe) export the current model.\n\n      Returns:\n        A tuple of `EvalResult` instance and the export results.\n\n      Raises:\n        RuntimeError: for any unexpected internal error.\n        TypeError: if evaluation result has wrong type.\n      \"\"\"\n      latest_ckpt_path = self._estimator.latest_checkpoint()\n      if not latest_ckpt_path:\n        self._log_err_msg('Estimator is not trained yet. Will start an '\n                          'evaluation when a checkpoint is ready.')\n        return _EvalResult(status=_EvalStatus.MISSING_CHECKPOINT), []\n\n      if latest_ckpt_path == self._previous_ckpt_path:\n        self._log_err_msg(\n            'No new checkpoint ready for evaluation. 
Skip the current '\n            'evaluation pass as evaluation results are expected to be same '\n            'for the same checkpoint.')\n        return _EvalResult(status=_EvalStatus.NO_NEW_CHECKPOINT), []\n\n      metrics = self._estimator.evaluate(\n          input_fn=self._eval_spec.input_fn,\n          steps=self._eval_spec.steps,\n          name=self._eval_spec.name,\n          checkpoint_path=latest_ckpt_path,\n          hooks=self._eval_spec.hooks)\n\n      # _EvalResult validates the metrics.\n      eval_result = _EvalResult(\n          status=_EvalStatus.EVALUATED,\n          metrics=metrics,\n          checkpoint_path=latest_ckpt_path)\n\n      is_the_final_export = (\n          eval_result.metrics[tf.compat.v1.GraphKeys.GLOBAL_STEP] >=\n          self._max_training_steps if self._max_training_steps else False)\n      export_results = self._export_eval_result(eval_result,\n                                                is_the_final_export)\n\n      if is_the_final_export:\n        tf.compat.v1.logging.debug(\n            'Calling exporter with the `is_the_final_export=True`.')\n        self._is_final_export_triggered = True\n\n      self._last_warning_time = 0\n      self._previous_ckpt_path = latest_ckpt_path\n      return eval_result, export_results\n\n    def _log_err_msg(self, message):\n      \"\"\"Prints warning `message` every 10 mins.\"\"\"\n      current_time = time.time()\n      if current_time - self._last_warning_time > 600:\n        tf.compat.v1.logging.warning(message)\n        self._last_warning_time = current_time\n\n    def _export_eval_result(self, eval_result, is_the_final_export):\n      \"\"\"Export `eval_result` according to exporters in `EvalSpec`.\"\"\"\n      export_dir_base = os.path.join(\n          tf.compat.as_str_any(self._estimator.model_dir),\n          tf.compat.as_str_any('export'))\n\n      export_results = []\n      for exporter in self._eval_spec.exporters:\n        export_results.append(\n            
exporter.export(\n                estimator=self._estimator,\n                export_path=os.path.join(\n                    tf.compat.as_str_any(export_dir_base),\n                    tf.compat.as_str_any(exporter.name)),\n                checkpoint_path=eval_result.checkpoint_path,\n                eval_result=eval_result.metrics,\n                is_the_final_export=is_the_final_export))\n      return export_results\n\n\nclass _EvalStatus(object):\n  \"\"\"The status of an evaluation event.\n\n  For local training and evaluation, the status can only be `EVALUATED` as\n  `Estimator.train` always generates a new checkpoint.\n\n  For distributed training and evaluation, a separated evaluator keeps looking\n  for new checkpoint. So, multiple situations might occur:\n\n  - EVALUATED: A new checkpoint is found since last evaluation.\n      `Estimator.evaluate` will be invoked.\n  - MISSING_CHECKPOINT: No checkpoint can be found. Typically, this means\n      the trainer has not yet produced any checkpoint.\n  - NO_NEW_CHECKPOINT: No new checkpoint can be found since last evaluation.\n      Typically, this means the trainer has not yet produced any new checkpoint.\n  \"\"\"\n\n  EVALUATED = 'evaluated'\n  MISSING_CHECKPOINT = 'missing checkpoint'\n  NO_NEW_CHECKPOINT = 'no new checkpoint'\n\n\nclass _EvalResult(\n    collections.namedtuple('EvalResult',\n                           ['status', 'metrics', 'checkpoint_path'])):\n  \"\"\"_EvalResult holds the result of an evaluation event.\"\"\"\n\n  def __new__(cls, status, metrics=None, checkpoint_path=None):\n    \"\"\"Creates a validated `_EvalResult`.\n\n    Args:\n      status: See `_EvalStatus`.\n      metrics: The evaluation results returned by `Estimator.evaluate`. Only set\n        if status is `EVALUATED`.\n      checkpoint_path: The corresponding checkpoint path for the `metrics`. 
Only\n        set if status is `EVALUATED`.\n\n    Returns:\n      A validated `_EvalResult` object.\n\n    Raises:\n      ValueError: If validation fails.\n      TypeError: If any of the arguments is not the expected type.\n    \"\"\"\n\n    if status != _EvalStatus.EVALUATED:\n      if metrics:\n        raise ValueError(\n            'metrics must be `None` if status is not {}; got status {},'\n            ' metrics {}'.format(_EvalStatus.EVALUATED, status, metrics))\n      if checkpoint_path:\n        raise ValueError(\n            'checkpoint must be `None` if status is not {}; got status {}, '\n            'checkpoint_path {}'.format(_EvalStatus.EVALUATED, status,\n                                        checkpoint_path))\n      return super(_EvalResult, cls).__new__(cls, status, metrics,\n                                             checkpoint_path)\n\n    # Now, evaluated case.\n    assert status == _EvalStatus.EVALUATED\n\n    # Validates metrics.\n    if not metrics:\n      raise ValueError(\n          'Internal error: `Estimator.evaluate` should never return empty '\n          'metrics.')\n    if not isinstance(metrics, dict):\n      raise TypeError(\n          '`Estimator.evaluate` should return dict. Given {}.'.format(\n              type(metrics)))\n    if tf.compat.v1.GraphKeys.GLOBAL_STEP not in metrics:\n      raise ValueError(\n          'Internal error: `Estimator.evaluate` result should have '\n          '`global_step` in result. 
Given {}'.format(metrics))\n\n    # Validates checkpoint_path.\n    if not checkpoint_path:\n      raise ValueError(\n          'Internal error: `checkpoint_path` should never be empty.')\n\n    return super(_EvalResult, cls).__new__(cls, status, metrics,\n                                           checkpoint_path)\n\n\nclass _ContinuousEvalListener(object):\n  \"\"\"Interface for listeners that take action before or after evaluation.\"\"\"\n\n  def before_eval(self):\n    \"\"\"Called before evaluation.\n\n    Returns:\n      `False` if you want to skip the current evaluation and early stop the\n      continuous evaluation; `True` otherwise.\n    \"\"\"\n    return True\n\n  def after_eval(self, eval_result):\n    \"\"\"Called after the evaluation is executed.\n\n    Args:\n      eval_result: An `_EvalResult` instance.\n\n    Returns:\n      False if you want to early stop continuous evaluation; `True` otherwise.\n    \"\"\"\n    del eval_result\n    return True\n\n\ndef _assert_eval_spec(eval_spec):\n  \"\"\"Raise error if `eval_spec` is not of the right type.\"\"\"\n  if not isinstance(eval_spec, EvalSpec):\n    raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`. '\n                    'Got: {}'.format(type(eval_spec)))\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/training_test.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for training.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport glob\nimport json\nimport os\nimport random\nimport shutil\nimport tempfile\nimport time\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.training import basic_session_run_hooks\nfrom tensorflow.python.training import server_lib\nfrom tensorflow_estimator.python.estimator import estimator as estimator_lib\nfrom tensorflow_estimator.python.estimator import exporter as exporter_lib\nfrom tensorflow_estimator.python.estimator import model_fn as model_fn_lib\nfrom tensorflow_estimator.python.estimator import run_config as run_config_lib\nfrom tensorflow_estimator.python.estimator import training\nfrom tensorflow_estimator.python.estimator.canned import dnn\nfrom tensorflow_estimator.python.estimator.canned import prediction_keys\nfrom tensorflow_estimator.python.estimator.export import export as export_lib\n\n_DEFAULT_EVAL_STEPS = 100\n_DEFAULT_EVAL_DELAY_SECS = 120\n_DEFAULT_EVAL_THROTTLE_SECS = 600\n_DELAY_SECS_PER_WORKER = 5\n_GLOBAL_STEP_KEY = tf.compat.v1.GraphKeys.GLOBAL_STEP\n_INVALID_INPUT_FN_MSG = '`input_fn` must be callable'\n_INVALID_HOOK_MSG = 'All hooks must be 
`SessionRunHook` instances'\n_INVALID_MAX_STEPS_MSG = 'Must specify max_steps > 0'\n_INVALID_STEPS_MSG = 'Must specify steps > 0'\n_INVALID_NAME_MSG = '`name` must be string'\n_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify start_delay_secs >= 0'\n_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0'\n_INVALID_ESTIMATOR_MSG = '`estimator` must have type `tf.estimator.Estimator`'\n_INVALID_SAVING_LISTENER_MSG = (\n    'All saving_listeners must be `CheckpointSaverListener` instances')\n_STALE_CHECKPOINT_MSG = 'There was no new checkpoint after the training.'\n_INVALID_EXPORTER_MSG = '`exporters` must be an Exporter'\n_INVALID_EXPORTER_NAME_TYPE_MSG = 'An Exporter must have a string name'\n_DUPLICATE_EXPORTER_NAMES_MSG = '`exporters` must have unique names.'\n_NONE_EXPORTER_NAME_MSG = (\n    'An Exporter cannot have a name that is `None` or empty.')\n_INVALID_TRAIN_SPEC_MSG = '`train_spec` must have type `tf.estimator.TrainSpec`'\n_INVALID_EVAL_SPEC_MSG = '`eval_spec` must have type `tf.estimator.EvalSpec`'\n_EVAL_SPEC_OR_NONE_MSG = (\n    '`eval_spec` must be either `None` or have type `tf.estimator.EvalSpec`')\n_INVALID_EVAL_LISTENER_MSG = 'must have type `_ContinuousEvalListener`'\n_INVALID_CONFIG_FOR_STD_SERVER_MSG = 'Could not start server; .*TF_CONFIG'\n_INVALID_LOCAL_TASK_WITH_CLUSTER = '`task.type` in TF_CONFIG cannot be `local`'\n_INVALID_TASK_TYPE = '`estimator.config` must have task_type set.'\n_INPROPER_THROTTL_SECS = (\n    'EvalSpec.throttle_secs is set as 0.*Please consider to increase')\n\n# The message should NOT have 'local' word as part of it. As (?!word) is looking\n# ahead, so, the $ (ending) check is required; otherwise, it will match\n# partially and return successuful.\n_INVALID_TASK_TO_RUN = (\n    'Task type .* is not supported. 
Supported task types are ((?!local).)*$')\n_INVALID_EMPTY_EVAL_RESULT_ERR = (\n    'Internal error: `Estimator.evaluate` should never return empty metrics')\n_INVALID_EVAL_RESULT_TYPE_ERR = '`Estimator.evaluate` should return dict.'\n_MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR = (\n    'Internal error: `Estimator.evaluate` result should have `global_step`')\n_INVALID_EVAL_TASK_ID_ERR = (\n    'there can only be one `evaluator` task .*with task id 0')\n\n_TF_CONFIG_FOR_CHIEF = {\n    'cluster': {\n        run_config_lib.TaskType.CHIEF: ['host0:0'],\n        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']\n    },\n    'task': {\n        'type': run_config_lib.TaskType.CHIEF,\n        'index': 0\n    }\n}\n\n_TF_CONFIG_FOR_MASTER = {\n    'cluster': {\n        run_config_lib.TaskType.MASTER: ['host0:0'],\n        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']\n    },\n    'task': {\n        'type': run_config_lib.TaskType.MASTER,\n        'index': 0\n    }\n}\n\n_TF_CONFIG_FOR_WORKER = {\n    'cluster': {\n        run_config_lib.TaskType.CHIEF: ['host0:0'],\n        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']\n    },\n    'task': {\n        'type': run_config_lib.TaskType.WORKER,\n        'index': 1\n    }\n}\n\n_TF_CONFIG_FOR_PS = {\n    'cluster': {\n        run_config_lib.TaskType.CHIEF: ['host0:0'],\n        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']\n    },\n    'task': {\n        'type': run_config_lib.TaskType.PS,\n        'index': 1\n    }\n}\n\n_TF_CONFIG_FOR_EVALUATOR = {\n    'cluster': {\n        run_config_lib.TaskType.CHIEF: ['host0:0'],\n        run_config_lib.TaskType.PS: ['host1:1', 'host2:2'],\n        run_config_lib.TaskType.WORKER: ['host3:3', 'host4:4']\n    },\n    'task': 
{\n        'type': run_config_lib.TaskType.EVALUATOR,\n        'index': 0\n    }\n}\n\n_TF_CONFIG_FOR_GOOGLE = {'environment': 'google'}\n\n\nclass _FakeHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Fake implementation of `SessionRunHook`.\"\"\"\n\n\nclass _InvalidHook(object):\n  \"\"\"Invalid hook (not a subclass of `SessionRunHook`).\"\"\"\n\n\nclass _InvalidCheckpointSaverListener(object):\n  \"\"\"Invalid hook (not a subclass of `CheckpointSaverListener`).\"\"\"\n\n\ndef _create_exporter(name):\n\n  class FakeExporter(exporter_lib.Exporter):\n\n    def __init__(self, name):\n      self._name = name\n\n    @property\n    def name(self):\n      return self._name\n\n    def export(self, *args, **kwargs):\n      del args, kwargs\n\n  return FakeExporter(name=name)\n\n\ndef _create_run_config_with_cluster_spec(tf_config):\n  with tf.compat.v1.test.mock.patch.dict('os.environ',\n                                         {'TF_CONFIG': json.dumps(tf_config)}):\n    return run_config_lib.RunConfig()\n\n\nclass TrainSpecTest(tf.test.TestCase):\n  \"\"\"Tests TrainSpec.\"\"\"\n\n  def testRequiredArgumentsSet(self):\n    \"\"\"Tests that no errors are raised when all required arguments are set.\"\"\"\n    spec = training.TrainSpec(input_fn=lambda: 1)\n    self.assertEqual(1, spec.input_fn())\n    self.assertIsNone(spec.max_steps)\n    self.assertEqual(0, len(spec.hooks))\n\n  def testAllArgumentsSet(self):\n    \"\"\"Tests that no errors are raised when all arguments are set.\"\"\"\n    hooks = [_FakeHook()]\n    spec = training.TrainSpec(input_fn=lambda: 1, max_steps=2, hooks=hooks)\n    self.assertEqual(1, spec.input_fn())\n    self.assertEqual(2, spec.max_steps)\n    self.assertEqual(tuple(hooks), spec.hooks)\n\n  def testInvalidInputFn(self):\n    with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG):\n      training.TrainSpec(input_fn='invalid')\n\n  def testInvalidMaxStep(self):\n    with self.assertRaisesRegexp(ValueError, _INVALID_MAX_STEPS_MSG):\n    
  training.TrainSpec(input_fn=lambda: 1, max_steps=0)\n\n  def testInvalidHook(self):\n    with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):\n      training.TrainSpec(input_fn=lambda: 1, hooks=[_InvalidHook()])\n\n  def testInvalidSavingListener(self):\n    with self.assertRaisesRegexp(TypeError, _INVALID_SAVING_LISTENER_MSG):\n      training.TrainSpec(input_fn=lambda: 1,\n                         saving_listeners=[_InvalidCheckpointSaverListener()])\n\n\nclass EvalSpecTest(tf.test.TestCase):\n  \"\"\"Tests EvalSpec.\"\"\"\n\n  def testRequiredArgumentsSet(self):\n    \"\"\"Tests that no errors are raised when all required arguments are set.\"\"\"\n    spec = training.EvalSpec(input_fn=lambda: 1)\n    self.assertEqual(1, spec.input_fn())\n    self.assertEqual(_DEFAULT_EVAL_STEPS, spec.steps)\n    self.assertIsNone(spec.name)\n    self.assertEqual(0, len(spec.hooks))\n    self.assertEqual(0, len(spec.exporters))\n    self.assertEqual(_DEFAULT_EVAL_DELAY_SECS, spec.start_delay_secs)\n    self.assertEqual(_DEFAULT_EVAL_THROTTLE_SECS, spec.throttle_secs)\n\n  def testAllArgumentsSet(self):\n    \"\"\"Tests that no errors are raised when all arguments are set.\"\"\"\n    hooks = [_FakeHook()]\n    exporter = _create_exporter('a')\n\n    spec = training.EvalSpec(\n        input_fn=lambda: 1,\n        steps=2,\n        name='name',\n        hooks=hooks,\n        exporters=exporter,\n        start_delay_secs=3,\n        throttle_secs=4)\n    self.assertEqual(1, spec.input_fn())\n    self.assertEqual(2, spec.steps)\n    self.assertEqual('name', spec.name)\n    self.assertEqual(tuple(hooks), spec.hooks)\n    self.assertEqual((exporter,), spec.exporters)\n    self.assertEqual(3, spec.start_delay_secs)\n    self.assertEqual(4, spec.throttle_secs)\n\n  def testListOfExporters(self):\n    \"\"\"Tests that no errors are raised with multiple exporters.\"\"\"\n    exporters = [_create_exporter('a'), _create_exporter('b')]\n\n    spec = training.EvalSpec(input_fn=lambda: 
1, exporters=exporters)\n    self.assertEqual(1, spec.input_fn())\n    self.assertEqual(tuple(exporters), spec.exporters)\n\n  def testInvalidInputFn(self):\n    with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG):\n      training.EvalSpec(input_fn='invalid')\n\n  def testInvalidMaxStep(self):\n    with self.assertRaisesRegexp(ValueError, _INVALID_STEPS_MSG):\n      training.EvalSpec(input_fn=lambda: 1, steps=0)\n\n  def testInvalidName(self):\n    with self.assertRaisesRegexp(TypeError, _INVALID_NAME_MSG):\n      training.EvalSpec(input_fn=lambda: 1, name=123)\n\n  def testInvalidHook(self):\n    with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):\n      training.EvalSpec(input_fn=lambda: 1, hooks=[_InvalidHook()])\n\n  def testInvalidDelaySecs(self):\n    with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_DELAY_SECS_MSG):\n      training.EvalSpec(input_fn=lambda: 1, start_delay_secs=-1)\n\n  def testInvalidThrottleSecs(self):\n    with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_THROTTLE_SECS_MSG):\n      training.EvalSpec(input_fn=lambda: 1, throttle_secs=-1)\n\n  def testInvalidTypeOfListOfExporters(self):\n    with self.assertRaisesRegexp(TypeError, _INVALID_EXPORTER_MSG):\n      training.EvalSpec(\n          input_fn=lambda: 1, exporters=[_create_exporter('a'),\n                                         _FakeHook()])\n\n  def testInvalidTypeOfIndividualExporter(self):\n    with self.assertRaisesRegexp(TypeError, _INVALID_EXPORTER_MSG):\n      training.EvalSpec(input_fn=lambda: 1, exporters=_FakeHook())\n\n  def testInvalidTypeOfExporterName(self):\n    with self.assertRaisesRegexp(ValueError, _INVALID_EXPORTER_NAME_TYPE_MSG):\n      training.EvalSpec(\n          input_fn=lambda: 1, exporters=_create_exporter(name=123))\n\n  def testMultipleExportersWithTheSameName(self):\n    with self.assertRaisesRegexp(ValueError, _DUPLICATE_EXPORTER_NAMES_MSG):\n      training.EvalSpec(\n          input_fn=lambda: 1,\n          
exporters=[_create_exporter('a'),\n                     _create_exporter('a')])\n\n  def testMultipleExportersAndOneWithoutAName(self):\n    with self.assertRaisesRegexp(ValueError, _NONE_EXPORTER_NAME_MSG):\n      training.EvalSpec(\n          input_fn=lambda: 1,\n          exporters=[_create_exporter('a'),\n                     _create_exporter(None)])\n\n  def testSingleExporterWithoutAName(self):\n    with self.assertRaisesRegexp(ValueError, _NONE_EXPORTER_NAME_MSG):\n      training.EvalSpec(input_fn=lambda: 1, exporters=_create_exporter(None))\n\n\nclass TrainAndEvaluateTest(tf.test.TestCase):\n\n  def test_run_task(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    with tf.compat.v1.test.mock.patch.object(\n        training, '_TrainingExecutor') as mock_executor:\n      mock_executor_instance = tf.compat.v1.test.mock.Mock()\n      mock_executor.return_value = mock_executor_instance\n      training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)\n      mock_executor.assert_called_with(\n          estimator=mock_est,\n          train_spec=mock_train_spec,\n          eval_spec=mock_eval_spec)\n      self.assertTrue(mock_executor_instance.run.called)\n\n  def test_error_out_if_evaluator_task_id_is_non_zero(self):\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n        },\n        'task': {\n            'type': run_config_lib.TaskType.EVALUATOR,\n            'index': 1\n        }\n    }\n\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.config = _create_run_config_with_cluster_spec(tf_config)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    with 
self.assertRaisesRegexp(ValueError, _INVALID_EVAL_TASK_ID_ERR):\n      training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)\n\n  def test_invalid_estimator(self):\n    invalid_estimator = object()\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG):\n      training.train_and_evaluate(invalid_estimator, mock_train_spec,\n                                  mock_eval_spec)\n\n  def test_fail_fast_if_invalid_eval_spec(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    invalid_eval_spec = object()\n\n    with tf.compat.v1.test.mock.patch.object(\n        training, '_TrainingExecutor') as mock_executor:\n      with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):\n        training.train_and_evaluate(mock_est, mock_train_spec,\n                                    invalid_eval_spec)\n\n      mock_executor.assert_not_called()\n\n\nclass TrainingExecutorConstructorTest(tf.test.TestCase):\n  \"\"\"Tests constructor of _TrainingExecutor.\"\"\"\n\n  def test_required_arguments_set(self):\n    estimator = estimator_lib.Estimator(model_fn=lambda features: features)\n    train_spec = training.TrainSpec(input_fn=lambda: 1)\n    eval_spec = training.EvalSpec(input_fn=lambda: 1)\n\n    executor = training._TrainingExecutor(estimator, train_spec, eval_spec)\n    self.assertEqual(estimator, executor.estimator)\n\n  def test_invalid_estimator(self):\n    invalid_estimator = object()\n    train_spec = training.TrainSpec(input_fn=lambda: 1)\n    eval_spec = training.EvalSpec(input_fn=lambda: 1)\n\n    with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG):\n      training._TrainingExecutor(invalid_estimator, train_spec, eval_spec)\n\n  def test_invalid_train_spec(self):\n    
estimator = estimator_lib.Estimator(model_fn=lambda features: features)\n    invalid_train_spec = object()\n    eval_spec = training.EvalSpec(input_fn=lambda: 1)\n\n    with self.assertRaisesRegexp(TypeError, _INVALID_TRAIN_SPEC_MSG):\n      training._TrainingExecutor(estimator, invalid_train_spec, eval_spec)\n\n  def test_invalid_eval_spec(self):\n    estimator = estimator_lib.Estimator(model_fn=lambda features: features)\n    train_spec = training.TrainSpec(input_fn=lambda: 1)\n    invalid_eval_spec = object()\n\n    with self.assertRaisesRegexp(TypeError, _EVAL_SPEC_OR_NONE_MSG):\n      training._TrainingExecutor(estimator, train_spec, invalid_eval_spec)\n\n  def test_eval_spec_none(self):\n    estimator = estimator_lib.Estimator(model_fn=lambda features: features)\n    train_spec = training.TrainSpec(input_fn=lambda: 1)\n    eval_spec = None\n\n    # Tests that no error is raised.\n    training._TrainingExecutor(estimator, train_spec, eval_spec)\n\n  def test_invalid_train_hooks(self):\n    estimator = estimator_lib.Estimator(model_fn=lambda features: features)\n    train_spec = training.TrainSpec(input_fn=lambda: 1)\n    eval_spec = training.EvalSpec(input_fn=lambda: 1)\n    invalid_train_hooks = [object()]\n\n    with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG):\n      training._TrainingExecutor(\n          estimator, train_spec, eval_spec, train_hooks=invalid_train_hooks)\n\n  def test_invalid_continuous_eval_listener(self):\n    estimator = estimator_lib.Estimator(model_fn=lambda features: features)\n    train_spec = training.TrainSpec(input_fn=lambda: 1)\n    eval_spec = training.EvalSpec(input_fn=lambda: 1)\n    invalid_continuous_eval_listener = object()\n\n    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_LISTENER_MSG):\n      training._TrainingExecutor(\n          estimator,\n          train_spec,\n          eval_spec,\n          continuous_eval_listener=invalid_continuous_eval_listener)\n\n\nclass 
_TrainingExecutorTrainingTest(object):\n  \"\"\"Tests training of _TrainingExecutor.\"\"\"\n\n  def __init__(self, run_config):\n    self._run_config = run_config\n\n  def _run_task(self, executor):\n    # We should not call executor.run as the test here is intended to test\n    # run_foo explicitly (foo is the task type).\n    return getattr(executor, 'run_' + self._run_config.task_type)()\n\n  @tf.compat.v1.test.mock.patch.object(time, 'sleep')\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_train_with_train_spec(self, mock_server, unused_mock_sleep):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.config = self._run_config\n    train_spec = training.TrainSpec(\n        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n    mock_server_instance = mock_server.return_value\n\n    executor = training._TrainingExecutor(mock_est, train_spec, mock_eval_spec)\n    self._run_task(executor)\n\n    mock_server.assert_called_with(\n        mock_est.config.cluster_spec,\n        job_name=mock_est.config.task_type,\n        task_index=mock_est.config.task_id,\n        config=tf.compat.v1.test.mock.ANY,\n        protocol=None,\n        start=False)\n\n    self.assertTrue(mock_server_instance.start.called)\n\n    mock_est.train.assert_called_with(\n        input_fn=train_spec.input_fn,\n        max_steps=train_spec.max_steps,\n        hooks=list(train_spec.hooks),\n        saving_listeners=tf.compat.v1.test.mock.ANY)\n    mock_est.evaluate.assert_not_called()\n    mock_est.export_saved_model.assert_not_called()\n\n  @tf.compat.v1.test.mock.patch.object(time, 'sleep')\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_train_with_no_eval_spec(self, mock_server, unused_mock_sleep):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.config = self._run_config\n    train_spec = 
training.TrainSpec(\n        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])\n    eval_spec = None\n    mock_server_instance = mock_server.return_value\n\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    self._run_task(executor)\n\n    mock_server.assert_called_with(\n        mock_est.config.cluster_spec,\n        job_name=mock_est.config.task_type,\n        task_index=mock_est.config.task_id,\n        config=tf.compat.v1.test.mock.ANY,\n        protocol=None,\n        start=False)\n\n    self.assertTrue(mock_server_instance.start.called)\n\n    mock_est.train.assert_called_with(\n        input_fn=train_spec.input_fn,\n        max_steps=train_spec.max_steps,\n        hooks=list(train_spec.hooks),\n        saving_listeners=tf.compat.v1.test.mock.ANY)\n    mock_est.evaluate.assert_not_called()\n    mock_est.export_saved_model.assert_not_called()\n\n  @tf.compat.v1.test.mock.patch.object(time, 'sleep')\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_train_with_train_hooks(self, unused_mock_server, unused_mock_sleep):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.config = self._run_config\n    train_spec = training.TrainSpec(\n        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n    extra_hooks = [_FakeHook()]\n\n    executor = training._TrainingExecutor(\n        mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks)\n    self._run_task(executor)\n\n    mock_est.train.assert_called_with(\n        input_fn=train_spec.input_fn,\n        max_steps=train_spec.max_steps,\n        hooks=list(train_spec.hooks) + extra_hooks,\n        saving_listeners=tf.compat.v1.test.mock.ANY)\n\n  @tf.compat.v1.test.mock.patch.object(time, 'sleep')\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):\n    
mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.config = self._run_config\n    mock_train_spec = tf.compat.v1.test.mock.Mock(\n        spec=training.TrainSpec, hooks=[])\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec,\n                                          mock_eval_spec)\n    tf_config = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}\n    with tf.compat.v1.test.mock.patch.dict('os.environ', tf_config):\n      self._run_task(executor)\n      mock_server.assert_not_called()\n\n  def test_fail_with_empty_cluster_spec(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    mock_est.config = tf.compat.v1.test.mock.PropertyMock(\n        spec=run_config_lib.RunConfig)\n    mock_est.config.cluster_spec = None\n    mock_est.config.master = 'grpc://...'\n    mock_est.config.task_type = 'worker'\n    mock_est.config.task_id = 2\n\n    with self.assertRaisesRegexp(RuntimeError,\n                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):\n      self._run_task(\n          training._TrainingExecutor(mock_est, mock_train_spec, mock_eval_spec))\n\n  def test_fail_with_empty_master(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    mock_est.config = tf.compat.v1.test.mock.PropertyMock(\n        spec=run_config_lib.RunConfig)\n    mock_est.config.cluster_spec = tf.train.ClusterSpec(\n        {'worker': ['dummy', 'dummy1']})\n    mock_est.config.master = ''\n    mock_est.config.task_type = 'worker'\n    mock_est.config.task_id = 2\n\n    with 
self.assertRaisesRegexp(RuntimeError,\n                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):\n      self._run_task(\n          training._TrainingExecutor(mock_est, mock_train_spec, mock_eval_spec))\n\n  @tf.compat.v1.test.mock.patch.object(time, 'sleep')\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_single_worker_node_with_empty_tf_master(self, mock_server,\n                                                   unused_mock_sleep):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(\n        spec=training.TrainSpec, hooks=[])\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    mock_est.config = tf.compat.v1.test.mock.PropertyMock(\n        spec=run_config_lib.RunConfig)\n    # Single node cluster.\n    mock_est.config.cluster_spec = tf.train.ClusterSpec({'worker': ['dummy']})\n    mock_est.config.master = ''\n    mock_est.config.task_type = 'worker'\n    mock_est.config.task_id = 2\n\n    self._run_task(\n        training._TrainingExecutor(mock_est, mock_train_spec, mock_eval_spec))\n    self.assertTrue(mock_est.train.called)\n    mock_server.assert_not_called()\n\n  def test_fail_with_empty_task_type(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    mock_est.config = tf.compat.v1.test.mock.PropertyMock(\n        spec=run_config_lib.RunConfig)\n    mock_est.config.cluster_spec = tf.train.ClusterSpec({'worker': ['dummy']})\n    mock_est.config.master = 'grpc://...'\n    mock_est.config.task_type = ''\n    mock_est.config.task_id = 2\n\n    with self.assertRaisesRegexp(RuntimeError,\n                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):\n      self._run_task(\n          training._TrainingExecutor(mock_est, mock_train_spec, 
mock_eval_spec))\n\n  def test_fail_with_none_task_id(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    mock_est.config = tf.compat.v1.test.mock.PropertyMock(\n        spec=run_config_lib.RunConfig)\n    mock_est.config.cluster_spec = tf.train.ClusterSpec({'worker': ['dummy']})\n    mock_est.config.master = 'grpc://...'\n    mock_est.config.task_type = 'worker'\n    mock_est.config.task_id = None\n\n    with self.assertRaisesRegexp(RuntimeError,\n                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):\n      self._run_task(\n          training._TrainingExecutor(mock_est, mock_train_spec, mock_eval_spec))\n\n\nclass TrainingExecutorRunWorkerTest(_TrainingExecutorTrainingTest,\n                                    tf.test.TestCase):\n  \"\"\"Tests run_worker of _TrainingExecutor.\"\"\"\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    _TrainingExecutorTrainingTest.__init__(\n        self,\n        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_WORKER))\n\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_delay_for_worker(self, _):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.config = self._run_config\n    mock_train_spec = tf.compat.v1.test.mock.Mock(\n        spec=training.TrainSpec, hooks=[])\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec,\n                                          mock_eval_spec)\n\n    expected_secs = (self._run_config.task_id + 1) * _DELAY_SECS_PER_WORKER\n    with tf.compat.v1.test.mock.patch.object(time, 'sleep') as mock_sleep:\n      mock_sleep.side_effect = lambda s: 
self.assertEqual(expected_secs, s)\n      self._run_task(executor)\n      self.assertTrue(mock_sleep.called)\n\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_delay_disabled_for_worker(self, _):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.config = self._run_config.replace(\n        experimental_max_worker_delay_secs=0)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(\n        spec=training.TrainSpec, hooks=[])\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec,\n                                          mock_eval_spec)\n\n    with tf.compat.v1.test.mock.patch.object(time, 'sleep') as mock_sleep:\n      self._run_task(executor)\n      self.assertFalse(mock_sleep.called)\n\n\nclass TrainingExecutorRunChiefTest(_TrainingExecutorTrainingTest,\n                                   tf.test.TestCase):\n  \"\"\"Tests run_chief of _TrainingExecutor.\"\"\"\n\n  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name\n    tf.test.TestCase.__init__(self, methodName)\n    _TrainingExecutorTrainingTest.__init__(\n        self,\n        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_CHIEF))\n\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_no_delay_for_chief(self, _):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.config = self._run_config\n    mock_train_spec = tf.compat.v1.test.mock.Mock(\n        spec=training.TrainSpec, hooks=[])\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec,\n                                          mock_eval_spec)\n\n    with tf.compat.v1.test.mock.patch.object(time, 'sleep') as mock_sleep:\n      self._run_task(executor)\n      mock_sleep.assert_not_called()\n\n\nclass 
TrainingExecutorRunMasterTest(tf.test.TestCase):\n  \"\"\"Tests run_chief of _TrainingExecutor.\"\"\"\n\n  def setUp(self):\n    self._run_config = _create_run_config_with_cluster_spec(\n        _TF_CONFIG_FOR_MASTER)\n\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_no_delay_for_master(self, _):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.evaluate = lambda *args, **kw: {\n        tf.compat.v1.GraphKeys.GLOBAL_STEP: 123\n    }\n    mock_est.config = self._run_config\n    mock_train_spec = tf.compat.v1.test.mock.Mock(\n        spec=training.TrainSpec, max_steps=123, hooks=[])\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(\n        spec=training.EvalSpec, exporters=[])\n\n    mock_train_spec.saving_listeners = tuple([])\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec,\n                                          mock_eval_spec)\n\n    with tf.compat.v1.test.mock.patch.object(time, 'sleep') as mock_sleep:\n      executor.run_master()\n      mock_sleep.assert_not_called()\n\n  @tf.compat.v1.test.mock.patch.object(time, 'sleep')\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_train_with_train_spec(self, mock_server, unused_mock_sleep):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.evaluate = lambda *args, **kw: {\n        tf.compat.v1.GraphKeys.GLOBAL_STEP: 123\n    }\n    mock_est.config = self._run_config\n    train_spec = training.TrainSpec(\n        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(\n        spec=training.EvalSpec, exporters=[])\n    mock_server_instance = mock_server.return_value\n\n    executor = training._TrainingExecutor(mock_est, train_spec, mock_eval_spec)\n    executor.run_master()\n\n    mock_server.assert_called_with(\n        mock_est.config.cluster_spec,\n        job_name=mock_est.config.task_type,\n        
task_index=mock_est.config.task_id,\n        config=tf.compat.v1.test.mock.ANY,\n        protocol=None,\n        start=False)\n\n    self.assertTrue(mock_server_instance.start.called)\n\n    mock_est.train.assert_called_with(\n        input_fn=train_spec.input_fn,\n        max_steps=train_spec.max_steps,\n        hooks=list(train_spec.hooks),\n        saving_listeners=tf.compat.v1.test.mock.ANY)\n    mock_est.export_saved_model.assert_not_called()\n\n  @tf.compat.v1.test.mock.patch.object(time, 'sleep')\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_train_with_no_eval_spec_fails(self, mock_server, unused_mock_sleep):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.evaluate = lambda *args, **kw: {\n        tf.compat.v1.GraphKeys.GLOBAL_STEP: 123\n    }\n    mock_est.config = self._run_config\n    train_spec = training.TrainSpec(\n        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])\n    eval_spec = None\n\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):\n      executor.run_master()\n\n  @tf.compat.v1.test.mock.patch.object(time, 'sleep')\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_train_with_train_hooks(self, mock_server, unused_mock_sleep):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.evaluate = lambda *args, **kw: {\n        tf.compat.v1.GraphKeys.GLOBAL_STEP: 123\n    }\n    mock_est.config = self._run_config\n    train_spec = training.TrainSpec(\n        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(\n        spec=training.EvalSpec, exporters=[])\n    extra_hooks = [_FakeHook()]\n\n    executor = training._TrainingExecutor(\n        mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks)\n    executor.run_master()\n\n    mock_est.train.assert_called_with(\n 
       input_fn=train_spec.input_fn,\n        max_steps=train_spec.max_steps,\n        hooks=list(train_spec.hooks) + extra_hooks,\n        saving_listeners=tf.compat.v1.test.mock.ANY)\n\n  @tf.compat.v1.test.mock.patch.object(time, 'sleep')\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.evaluate = lambda *args, **kw: {\n        tf.compat.v1.GraphKeys.GLOBAL_STEP: 123\n    }\n    mock_est.config = self._run_config\n    mock_train_spec = tf.compat.v1.test.mock.Mock(\n        spec=training.TrainSpec, max_steps=123, hooks=[])\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(\n        spec=training.EvalSpec, exporters=[])\n\n    mock_train_spec.saving_listeners = tuple([])\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec,\n                                          mock_eval_spec)\n    tf_config = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}\n    with tf.compat.v1.test.mock.patch.dict('os.environ', tf_config):\n      executor.run_master()\n      mock_server.assert_not_called()\n\n  def test_fail_with_empty_cluster_spec(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    mock_est.config = tf.compat.v1.test.mock.PropertyMock(\n        spec=run_config_lib.RunConfig)\n    mock_est.config.cluster_spec = None\n    mock_est.config.master = 'grpc://...'\n    mock_est.config.task_type = 'master'\n    mock_est.config.task_id = 2\n\n    mock_train_spec.saving_listeners = tuple([])\n\n    with self.assertRaisesRegexp(RuntimeError,\n                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):\n      training._TrainingExecutor(mock_est, mock_train_spec,\n                                 
mock_eval_spec).run_master()\n\n  def test_fail_with_empty_master(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    mock_est.config = tf.compat.v1.test.mock.PropertyMock(\n        spec=run_config_lib.RunConfig)\n    mock_est.config.cluster_spec = tf.train.ClusterSpec({\n        'master': ['dummy'],\n        'worker': ['dummy1']\n    })\n    mock_est.config.master = ''\n    mock_est.config.task_type = 'master'\n    mock_est.config.task_id = 0\n\n    mock_train_spec.saving_listeners = tuple([])\n\n    with self.assertRaisesRegexp(RuntimeError,\n                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):\n      training._TrainingExecutor(mock_est, mock_train_spec,\n                                 mock_eval_spec).run_master()\n\n  @tf.compat.v1.test.mock.patch.object(time, 'sleep')\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_single_master_node_with_empty_tf_master(self, mock_server,\n                                                   unused_mock_sleep):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.evaluate = lambda *args, **kw: {\n        tf.compat.v1.GraphKeys.GLOBAL_STEP: 123\n    }\n\n    mock_train_spec = tf.compat.v1.test.mock.Mock(\n        spec=training.TrainSpec, max_steps=123, hooks=[])\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(\n        spec=training.EvalSpec, exporters=[])\n\n    mock_est.config = tf.compat.v1.test.mock.PropertyMock(\n        spec=run_config_lib.RunConfig)\n    mock_est.config.cluster_spec = tf.train.ClusterSpec({'master': ['dummy']})\n    mock_est.config.master = ''\n    mock_est.config.task_type = 'master'\n    mock_est.config.task_id = 0\n\n    mock_train_spec.saving_listeners = tuple([])\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec,\n    
                                      mock_eval_spec)\n    executor.run_master()\n\n    mock_server.assert_not_called()\n    self.assertTrue(mock_est.train.called)\n\n  def test_fail_with_empty_task_type(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    mock_est.config = tf.compat.v1.test.mock.PropertyMock(\n        spec=run_config_lib.RunConfig)\n    mock_est.config.cluster_spec = tf.train.ClusterSpec({'master': ['dummy']})\n    mock_est.config.master = 'grpc://...'\n    mock_est.config.task_type = ''\n    mock_est.config.task_id = 2\n\n    mock_train_spec.saving_listeners = tuple([])\n\n    with self.assertRaisesRegexp(RuntimeError,\n                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):\n      training._TrainingExecutor(mock_est, mock_train_spec,\n                                 mock_eval_spec).run_master()\n\n  def test_fail_with_none_task_id(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    mock_est.config = tf.compat.v1.test.mock.PropertyMock(\n        spec=run_config_lib.RunConfig)\n    mock_est.config.cluster_spec = tf.train.ClusterSpec({'master': ['dummy']})\n    mock_est.config.master = 'grpc://...'\n    mock_est.config.task_type = 'master'\n    mock_est.config.task_id = None\n\n    mock_train_spec.saving_listeners = tuple([])\n\n    with self.assertRaisesRegexp(RuntimeError,\n                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):\n      training._TrainingExecutor(mock_est, mock_train_spec,\n                                 mock_eval_spec).run_master()\n\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def 
test_run_master_triggers_evaluate_and_export(self, _):\n\n    def estimator_train(saving_listeners, *args, **kwargs):\n      #  There shalt be a saving_listener.  Estimator is going to call\n      # `after_save`.\n      del args, kwargs\n      saving_listeners[0].begin()\n      saving_listeners[0].after_save(session=None, global_step_value=0)\n      saving_listeners[0].after_save(session=None, global_step_value=10)\n\n    mock_est = tf.compat.v1.test.mock.Mock(\n        spec=estimator_lib.Estimator, model_dir='path/', train=estimator_train)\n    mock_est.latest_checkpoint.return_value = 'checkpoint_path/'\n    mock_est.config = self._run_config\n\n    exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)\n    exporter.name = 'see_whether_export_is_called'\n\n    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: 1, steps=2, exporters=exporter)\n    eval_result = {_GLOBAL_STEP_KEY: train_spec.max_steps}\n    mock_est.evaluate.return_value = eval_result\n\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    executor.run_master()\n\n    mock_est.evaluate.assert_called_with(\n        name=eval_spec.name,\n        input_fn=eval_spec.input_fn,\n        steps=eval_spec.steps,\n        checkpoint_path='checkpoint_path/',\n        hooks=eval_spec.hooks)\n    self.assertEqual(1, exporter.export.call_count)\n    exporter.export.assert_called_with(\n        estimator=mock_est,\n        export_path=os.path.join('path/', 'export', exporter.name),\n        checkpoint_path='checkpoint_path/',\n        eval_result=eval_result,\n        is_the_final_export=True)\n\n  @tf.compat.v1.test.mock.patch.object(basic_session_run_hooks,\n                                       'SecondOrStepTimer')\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_run_master_throttle_eval(self, _, mock_timer_class):\n    mock_est = tf.compat.v1.test.mock.Mock(\n 
       spec=estimator_lib.Estimator, model_dir='path/')\n\n    mock_timer = tf.compat.v1.test.mock.Mock()\n    mock_timer_class.return_value = mock_timer\n\n    def estimator_train(saving_listeners, *args, **kwargs):\n      del args, kwargs\n      saving_listeners[0].begin()\n\n      # Call four times.\n      mock_timer.should_trigger_for_step.return_value = True\n      saving_listeners[0].after_save(session=None, global_step_value=None)\n\n      mock_timer.should_trigger_for_step.return_value = True\n      saving_listeners[0].after_save(session=None, global_step_value=None)\n\n      mock_timer.should_trigger_for_step.return_value = False\n      saving_listeners[0].after_save(session=None, global_step_value=None)\n\n      mock_timer.should_trigger_for_step.return_value = True\n      saving_listeners[0].after_save(session=None, global_step_value=None)\n\n    mock_est.train = estimator_train\n    mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']\n    mock_est.config = self._run_config\n\n    exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)\n    exporter.name = 'see_whether_export_is_called'\n\n    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)\n\n    mock_est.evaluate.side_effect = [{\n        _GLOBAL_STEP_KEY: train_spec.max_steps // 2\n    }, {\n        _GLOBAL_STEP_KEY: train_spec.max_steps\n    }]\n\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    executor.run_master()\n\n    self.assertEqual(2, mock_est.evaluate.call_count)\n    self.assertEqual(2, exporter.export.call_count)\n\n    is_final_export_list = [\n        call[1]['is_the_final_export']\n        for call in exporter.export.call_args_list\n    ]\n    self.assertEqual([False, True], is_final_export_list)\n\n  @tf.compat.v1.test.mock.patch.object(basic_session_run_hooks,\n                                      
 'SecondOrStepTimer')\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_run_master_throttle_eval_which_skips_final_ckpt(\n      self, _, mock_timer_class):\n    mock_est = tf.compat.v1.test.mock.Mock(\n        spec=estimator_lib.Estimator, model_dir='path/')\n\n    mock_timer = tf.compat.v1.test.mock.Mock()\n    mock_timer_class.return_value = mock_timer\n\n    def estimator_train(saving_listeners, *args, **kwargs):\n      del args, kwargs\n      saving_listeners[0].begin()\n\n      # Call tree times (one for first saving).\n      mock_timer.should_trigger_for_step.return_value = True\n      saving_listeners[0].after_save(session=None, global_step_value=0)\n\n      mock_timer.should_trigger_for_step.return_value = True\n      saving_listeners[0].after_save(session=None, global_step_value=125)\n\n      mock_timer.should_trigger_for_step.return_value = False\n      saving_listeners[0].after_save(session=None, global_step_value=250)\n\n      # At the end evaluate should be called even if throttle secs prevents it.\n      mock_timer.should_trigger_for_step.return_value = False\n      saving_listeners[0].end(session=None, global_step_value=300)\n\n    mock_est.train = estimator_train\n    mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']\n    mock_est.config = self._run_config\n\n    exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)\n    exporter.name = 'see_whether_export_is_called'\n\n    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)\n\n    mock_est.evaluate.side_effect = [{\n        _GLOBAL_STEP_KEY: train_spec.max_steps // 2\n    }, {\n        _GLOBAL_STEP_KEY: train_spec.max_steps\n    }]\n\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    executor.run_master()\n\n    self.assertEqual(2, mock_est.evaluate.call_count)\n    self.assertEqual(2, 
exporter.export.call_count)\n\n    is_final_export_list = [\n        call[1]['is_the_final_export']\n        for call in exporter.export.call_args_list\n    ]\n    self.assertEqual([False, True], is_final_export_list)\n\n\nclass TrainingExecutorRunEvaluatorTest(tf.test.TestCase):\n  \"\"\"Tests run_evaluator of _TrainingExecutor.\"\"\"\n\n  def _set_up_mock_est_to_train_and_evaluate_once(self, mock_est,\n                                                  mock_train_spec):\n    \"\"\"Sets global step in eval result to end the while True eval loop.\"\"\"\n    training_max_step = 200\n    mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: training_max_step}\n    mock_train_spec.max_steps = training_max_step\n\n  def test_evaluate_with_evaluate_spec(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.latest_checkpoint.return_value = 'latest_it_is'\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)\n\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: 1,\n        steps=2,\n        hooks=[_FakeHook()],\n        name='cont_eval',\n        start_delay_secs=0,\n        throttle_secs=0)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)\n    executor.run_evaluator()\n\n    mock_est.evaluate.assert_called_with(\n        name='cont_eval',\n        input_fn=eval_spec.input_fn,\n        steps=eval_spec.steps,\n        checkpoint_path='latest_it_is',\n        hooks=eval_spec.hooks)\n    self.assertFalse(mock_est.train.called)\n\n  def test_evaluate_with_no_eval_spec_fails(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.latest_checkpoint.return_value = 'latest_it_is'\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)\n\n    eval_spec = 
None\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)\n\n    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):\n      executor.run_evaluator()\n\n  def test_evaluate_with_train_hooks(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.latest_checkpoint.return_value = 'latest_it_is'\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)\n\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: 1,\n        steps=2,\n        hooks=[_FakeHook()],\n        name='cont_eval',\n        start_delay_secs=0,\n        throttle_secs=0)\n\n    # The train_hooks will not be called during eval.\n    mock_hook = tf.compat.v1.test.mock.Mock(\n        spec=tf.compat.v1.train.SessionRunHook)\n    executor = training._TrainingExecutor(\n        mock_est, mock_train_spec, eval_spec, train_hooks=[mock_hook])\n    executor.run_evaluator()\n\n    mock_hook.begin.assert_not_called()\n\n  def test_evaluate_multiple_times(self):\n    training_max_step = 200\n\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.model_dir = tf.compat.as_bytes(tf.compat.v1.test.get_temp_dir())\n    mock_est.evaluate.side_effect = [{\n        _GLOBAL_STEP_KEY: training_max_step // 2\n    }, {\n        _GLOBAL_STEP_KEY: training_max_step\n    }]\n    mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']\n\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_train_spec.max_steps = training_max_step\n\n    exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)\n    exporter.name = 'see_how_many_times_export_is_called'\n\n    mock_est.times_export_was_called = 0\n    mock_est.times_final_export_was_called = 0\n\n    def export(estimator, export_path, checkpoint_path, eval_result,\n               is_the_final_export):\n      
del export_path, checkpoint_path, eval_result\n      estimator.times_export_was_called += 1\n      # final_export is happened at the end.\n      self.assertEqual(0, estimator.times_final_export_was_called)\n      if is_the_final_export:\n        estimator.times_final_export_was_called += 1\n\n    exporter.export = export\n\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: 1,\n        start_delay_secs=0,\n        throttle_secs=0,\n        exporters=exporter)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)\n    executor.run_evaluator()\n\n    self.assertEqual(2, mock_est.evaluate.call_count)\n    self.assertEqual(2, mock_est.times_export_was_called)\n    self.assertEqual(1, mock_est.times_final_export_was_called)\n\n  def test_evaluate_listener_before_eval(self):\n    training_max_step = 200\n\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.model_dir = tf.compat.as_bytes(tf.compat.v1.test.get_temp_dir())\n    # Without early stopping, this eval will be run twice.\n    mock_est.evaluate.side_effect = [{\n        _GLOBAL_STEP_KEY: training_max_step // 2\n    }, {\n        _GLOBAL_STEP_KEY: training_max_step\n    }]\n    mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']\n\n    mock_train_spec = tf.compat.v1.test.mock.Mock(\n        spec=training.TrainSpec, hooks=[])\n    mock_train_spec.max_steps = training_max_step\n\n    class _Listener(training._ContinuousEvalListener):\n\n      def __init__(self):\n        self.call_count = 0\n\n      def before_eval(self):\n        self.call_count += 1\n        return self.call_count == 1\n\n    listener = _Listener()\n\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)\n\n    training._TrainingExecutor(\n        mock_est, mock_train_spec, eval_spec,\n        continuous_eval_listener=listener).run_evaluator()\n\n    # Before_eval returns False during the second time, so, evaluate will 
be\n    # called once.\n    self.assertEqual(1, mock_est.evaluate.call_count)\n    self.assertEqual(2, listener.call_count)\n\n  def test_evaluate_listener_after_eval(self):\n    training_max_step = 200\n\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.model_dir = tf.compat.as_bytes(tf.compat.v1.test.get_temp_dir())\n    # Without early stopping, this eval will be run twice.\n    expected_eval_metrics = [{\n        _GLOBAL_STEP_KEY: training_max_step // 2\n    }, {\n        _GLOBAL_STEP_KEY: training_max_step\n    }]\n    mock_est.evaluate.side_effect = expected_eval_metrics\n    mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']\n\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_train_spec.max_steps = training_max_step\n\n    class _Listener(training._ContinuousEvalListener):\n\n      def __init__(self):\n        self.call_count = 0\n\n      def after_eval(self, eval_result):\n        self.call_count += 1\n        self.eval_result = eval_result\n        return False\n\n    listener = _Listener()\n\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)\n\n    training._TrainingExecutor(\n        mock_est, mock_train_spec, eval_spec,\n        continuous_eval_listener=listener).run_evaluator()\n\n    # after_eval returns False during the first time, so, evaluate will be\n    # called once.\n    self.assertEqual(1, mock_est.evaluate.call_count)\n    self.assertEqual(1, listener.call_count)\n    self.assertAllEqual(expected_eval_metrics[0], listener.eval_result.metrics)\n    self.assertEqual('path_1', listener.eval_result.checkpoint_path)\n\n  def test_final_export_is_true_in_the_end(self):\n    training_max_step = 200\n\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.model_dir = tf.compat.as_bytes(tf.compat.v1.test.get_temp_dir())\n    mock_est.evaluate.side_effect = [{\n        _GLOBAL_STEP_KEY: 
training_max_step // 2\n    }, {\n        _GLOBAL_STEP_KEY: training_max_step\n    }]\n    mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']\n\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_train_spec.max_steps = training_max_step\n\n    mock_est.times_export_fn_was_called = 0\n    mock_est.times_the_final_export_was_true = 0\n\n    def export(estimator, export_path, checkpoint_path, eval_result,\n               is_the_final_export):\n      del export_path, checkpoint_path, eval_result\n      estimator.times_export_fn_was_called += 1\n      if is_the_final_export:\n        estimator.times_the_final_export_was_true += 1\n\n    exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)\n    exporter.name = 'see_how_many_times_export_is_called'\n    exporter.export = export\n\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: 1,\n        start_delay_secs=0,\n        throttle_secs=0,\n        exporters=exporter)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)\n    executor.run_evaluator()\n\n    self.assertEqual(2, mock_est.evaluate.call_count)\n    self.assertEqual(2, mock_est.times_export_fn_was_called)\n    self.assertEqual(1, mock_est.times_the_final_export_was_true)\n\n  def test_skip_evaluation_due_to_ckpt(self):\n    training_max_step = 200\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.evaluate.side_effect = [{\n        _GLOBAL_STEP_KEY: training_max_step // 2\n    }, {\n        _GLOBAL_STEP_KEY: training_max_step\n    }]\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_train_spec.max_steps = training_max_step\n\n    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)\n\n    # First two items are invalid, next two items are same.\n    mock_est.latest_checkpoint.side_effect = [\n        None, '', 'same', 'same', 'path_2'\n    ]\n\n    eval_spec = 
training.EvalSpec(\n        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=2)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)\n    with tf.compat.v1.test.mock.patch.object(tf.compat.v1.logging, 'warning') as mock_log:\n      executor.run_evaluator()\n\n    # Three checkpoint paths are invalid.\n    self.assertEqual(5, mock_est.latest_checkpoint.call_count)\n    self.assertEqual(2, mock_est.evaluate.call_count)\n\n    # Two warning logs are expected (last warning time is reset after a\n    # successuful evaluation)\n    self.assertEqual(2, mock_log.call_count)\n\n  def test_warning_if_throttle_secs_is_zero(self):\n    training_max_step = 200\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.evaluate.side_effect = [{_GLOBAL_STEP_KEY: training_max_step}]\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_train_spec.max_steps = training_max_step\n\n    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)\n\n    # We need to make the first one invalid, so it will check the\n    # throttle_secs=0.\n    mock_est.latest_checkpoint.side_effect = [None, 'path']\n\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)\n    with tf.compat.v1.test.mock.patch.object(tf.compat.v1.logging, 'warning') as mock_log:\n      executor.run_evaluator()\n\n    # First ckpt is invalid.\n    self.assertEqual(2, mock_est.latest_checkpoint.call_count)\n    self.assertEqual(1, mock_est.evaluate.call_count)\n\n    self.assertRegexpMatches(str(mock_log.call_args), _INPROPER_THROTTL_SECS)\n\n  def test_continuous_eval_listener_eval_result(self):\n    training_max_step = 200\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    expected_eval_metrics = [{\n        _GLOBAL_STEP_KEY: training_max_step // 2\n    }, 
{\n        _GLOBAL_STEP_KEY: training_max_step\n    }]\n    mock_est.evaluate.side_effect = expected_eval_metrics\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_train_spec.max_steps = training_max_step\n\n    class _Listener(training._ContinuousEvalListener):\n\n      def __init__(self):\n        self.eval_results = []\n\n      def after_eval(self, eval_result):\n        self.eval_results.append(eval_result)\n        return True\n\n    continuous_eval_listener = _Listener()\n\n    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)\n\n    # First two items are invalid, next two items are same.\n    mock_est.latest_checkpoint.side_effect = [\n        None, '', 'same', 'same', 'path_2'\n    ]\n    expected_eval_results = [\n        training._EvalResult(training._EvalStatus.MISSING_CHECKPOINT),\n        training._EvalResult(training._EvalStatus.MISSING_CHECKPOINT),\n        training._EvalResult(\n            training._EvalStatus.EVALUATED,\n            metrics=expected_eval_metrics[0],\n            checkpoint_path='same'),\n        training._EvalResult(training._EvalStatus.NO_NEW_CHECKPOINT),\n        training._EvalResult(\n            training._EvalStatus.EVALUATED,\n            metrics=expected_eval_metrics[1],\n            checkpoint_path='path_2'),\n    ]\n\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)\n\n    executor = training._TrainingExecutor(\n        mock_est,\n        mock_train_spec,\n        eval_spec,\n        continuous_eval_listener=continuous_eval_listener)\n    executor.run_evaluator()\n\n    # Three checkpoint paths are invalid.\n    self.assertEqual(5, mock_est.latest_checkpoint.call_count)\n    self.assertEqual(2, mock_est.evaluate.call_count)\n\n    self.assertEqual(5, len(continuous_eval_listener.eval_results))\n    for i, result in enumerate(continuous_eval_listener.eval_results):\n      
self.assertEqual(expected_eval_results[i].status, result.status)\n      self.assertAllEqual(expected_eval_results[i].metrics, result.metrics)\n      self.assertEqual(expected_eval_results[i].checkpoint_path,\n                       result.checkpoint_path)\n\n  def test_sleep_start_delay_secs(self):\n    training_max_step = 200\n    start_delay_secs = 123\n\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: training_max_step}\n    mock_est.model_dir = tf.compat.as_bytes(tf.compat.v1.test.get_temp_dir())\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_train_spec.max_steps = training_max_step\n\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: 1,\n        steps=2,\n        hooks=[_FakeHook()],\n        name='cont_eval',\n        start_delay_secs=start_delay_secs,\n        throttle_secs=0)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)\n    with tf.compat.v1.test.mock.patch.object(time, 'sleep') as mock_sleep:\n      executor.run_evaluator()\n      mock_sleep.assert_called_with(start_delay_secs)\n      self.assertTrue(mock_est.evaluate.called)\n\n  @tf.compat.v1.test.mock.patch.object(time, 'time')\n  @tf.compat.v1.test.mock.patch.object(time, 'sleep')\n  def test_throttle_secs(self, mock_sleep, mock_time):\n    throttle_secs = 123\n    operation_secs = 12\n\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)\n\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=throttle_secs)\n\n    mock_time.side_effect = [921, 921 + operation_secs]\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)\n    # Disable logging as it calls time.time also.\n    with 
tf.compat.v1.test.mock.patch.object(tf.compat.v1.logging, 'info'):\n      executor.run_evaluator()\n    mock_sleep.assert_called_with(throttle_secs - operation_secs)\n    self.assertTrue(mock_est.evaluate.called)\n\n  def test_that_export_is_called(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)\n\n    def export(estimator, *args, **kwargs):\n      del args, kwargs\n      estimator.export_was_called = True\n\n    exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)\n    exporter.name = 'see_whether_export_is_called'\n    exporter.export = export\n\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: 1,\n        steps=2,\n        start_delay_secs=0,\n        throttle_secs=0,\n        exporters=exporter)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)\n    executor.run_evaluator()\n\n    # Verify that export was called on the right estimator.\n    self.assertTrue(mock_est.export_was_called)\n\n  def test_errors_out_if_evaluate_returns_empty_dict(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    train_spec = training.TrainSpec(input_fn=lambda: 1)\n    eval_spec = training.EvalSpec(\n        input_fn=(lambda: 1), start_delay_secs=0, throttle_secs=0)\n    mock_est.evaluate.return_value = {}\n\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    with self.assertRaisesRegexp(ValueError, _INVALID_EMPTY_EVAL_RESULT_ERR):\n      executor.run_evaluator()\n\n  def test_errors_out_if_evaluate_returns_non_dict(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    train_spec = training.TrainSpec(input_fn=lambda: 1)\n    eval_spec = training.EvalSpec(\n        input_fn=(lambda: 1), start_delay_secs=0, throttle_secs=0)\n    
mock_est.evaluate.return_value = 123\n\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR):\n      executor.run_evaluator()\n\n  def test_errors_out_if_evaluate_returns_dict_without_global_step(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    train_spec = training.TrainSpec(input_fn=lambda: 1)\n    eval_spec = training.EvalSpec(\n        input_fn=(lambda: 1), start_delay_secs=0, throttle_secs=0)\n    mock_est.evaluate.return_value = {'loss': 123}\n\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    with self.assertRaisesRegexp(ValueError,\n                                 _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR):\n      executor.run_evaluator()\n\n\nclass TrainingExecutorRunPsTest(tf.test.TestCase):\n  \"\"\"Tests run_ps of _TrainingExecutor.\"\"\"\n\n  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')\n  def test_std_server(self, mock_server):\n    mock_server_instance = tf.compat.v1.test.mock.Mock()\n    mock_server.return_value = mock_server_instance\n\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.config = _create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec,\n                                          mock_eval_spec)\n    executor.run_ps()\n\n    mock_server.assert_called_with(\n        mock_est.config.cluster_spec,\n        job_name=mock_est.config.task_type,\n        task_index=mock_est.config.task_id,\n        config=tf.compat.v1.test.mock.ANY,\n        protocol=None,\n        start=False)\n\n    self.assertTrue(mock_server_instance.start.called)\n    self.assertTrue(mock_server_instance.join.called)\n\n  def 
test_fail_with_empty_cluster_spec(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    mock_est.config = tf.compat.v1.test.mock.PropertyMock(\n        spec=run_config_lib.RunConfig)\n    mock_est.config.cluster_spec = None\n    mock_est.config.master = 'grpc://...'\n    mock_est.config.task_type = 'ps'\n    mock_est.config.task_id = 2\n\n    mock_train_spec.saving_listeners = tuple([])\n\n    with self.assertRaisesRegexp(RuntimeError,\n                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):\n      training._TrainingExecutor(mock_est, mock_train_spec,\n                                 mock_eval_spec).run_ps()\n\n  def test_fail_with_empty_master(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    mock_est.config = tf.compat.v1.test.mock.PropertyMock(\n        spec=run_config_lib.RunConfig)\n    mock_est.config.cluster_spec = tf.train.ClusterSpec({'ps': ['dummy']})\n    mock_est.config.master = ''\n    mock_est.config.task_type = 'ps'\n    mock_est.config.task_id = 2\n\n    mock_train_spec.saving_listeners = tuple([])\n\n    with self.assertRaisesRegexp(RuntimeError,\n                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):\n      training._TrainingExecutor(mock_est, mock_train_spec,\n                                 mock_eval_spec).run_ps()\n\n  def test_fail_with_empty_task_type(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    mock_est.config = tf.compat.v1.test.mock.PropertyMock(\n        
spec=run_config_lib.RunConfig)\n    mock_est.config.cluster_spec = tf.train.ClusterSpec({'ps': ['dummy']})\n    mock_est.config.master = 'grpc://...'\n    mock_est.config.task_type = ''\n    mock_est.config.task_id = 2\n\n    with self.assertRaisesRegexp(RuntimeError,\n                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):\n      training._TrainingExecutor(mock_est, mock_train_spec,\n                                 mock_eval_spec).run_ps()\n\n  def test_fail_with_none_task_id(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    mock_est.config = tf.compat.v1.test.mock.PropertyMock(\n        spec=run_config_lib.RunConfig)\n    mock_est.config.cluster_spec = tf.train.ClusterSpec({'ps': ['dummy']})\n    mock_est.config.master = 'grpc://...'\n    mock_est.config.task_type = 'ps'\n    mock_est.config.task_id = None\n\n    with self.assertRaisesRegexp(RuntimeError,\n                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):\n      training._TrainingExecutor(mock_est, mock_train_spec,\n                                 mock_eval_spec).run_ps()\n\n\nclass StopAtSecsHookTest(tf.test.TestCase):\n  \"\"\"Tests StopAtSecsHook.\"\"\"\n\n  @tf.compat.v1.test.mock.patch.object(time, 'time')\n  def test_stops_after_time(self, mock_time):\n    mock_time.return_value = 1484695987.209386\n    hook = training._StopAtSecsHook(1000)\n    with tf.Graph().as_default():\n      no_op = tf.no_op()\n      # some time passed before training starts\n      mock_time.return_value += 250\n      with tf.compat.v1.train.MonitoredSession(hooks=[hook]) as sess:\n        self.assertFalse(sess.should_stop())\n        sess.run(no_op)\n        self.assertFalse(sess.should_stop())\n        mock_time.return_value += 500\n        sess.run(no_op)\n        self.assertFalse(sess.should_stop())\n        
mock_time.return_value += 400\n        sess.run(no_op)\n        self.assertFalse(sess.should_stop())\n        mock_time.return_value += 200\n        sess.run(no_op)\n        self.assertTrue(sess.should_stop())\n\n\nclass TrainingExecutorRunLocalTest(tf.test.TestCase):\n  \"\"\"Tests run_local of _TrainingExecutor.\"\"\"\n\n  def _model_fn(self, features, labels, mode):\n    del labels\n    with tf.control_dependencies([features]):\n      train_op = tf.compat.v1.assign_add(tf.compat.v1.train.get_global_step(),\n                                         1)\n    return model_fn_lib.EstimatorSpec(\n        mode,\n        loss=tf.constant(0.),\n        train_op=train_op,\n        predictions=tf.constant([[10.]]),\n        eval_metric_ops={\n            'mean_of_features': tf.compat.v1.metrics.mean(features)\n        })\n\n  def _input_fn(self, repeat=True):\n    ds = tf.compat.v1.data.Dataset.from_tensors([1])\n    if repeat:\n      return ds.repeat()\n    return ds\n\n  def unique_checkpoint_every_time_fn(self):\n    return 'checkpoint_path_%s/' % random.random()\n\n  def test_runs_evaluate_with_every_new_checkpoint(self):\n    est = estimator_lib.Estimator(\n        model_fn=self._model_fn,\n        config=run_config_lib.RunConfig(save_checkpoints_steps=10))\n    mock_est = tf.compat.v1.test.mock.Mock(\n        spec=estimator_lib.Estimator, wraps=est)\n\n    mock_est.times_export_was_called = 0\n    mock_est.times_final_export_was_called = 0\n\n    def export(estimator, export_path, checkpoint_path, eval_result,\n               is_the_final_export):\n      del export_path, checkpoint_path, eval_result\n      estimator.times_export_was_called += 1\n      # final_export is happened at the end.\n      self.assertEqual(0, estimator.times_final_export_was_called)\n      if is_the_final_export:\n        estimator.times_final_export_was_called += 1\n\n    exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)\n    exporter.name = 
'see_how_many_times_export_is_called'\n    exporter.export = export\n\n    train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=22)\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: self._input_fn(repeat=False),\n        throttle_secs=0,\n        exporters=exporter)\n\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    executor.run_local()\n\n    self.assertEqual(1, mock_est.train.call_count)\n    self.assertEqual(3, mock_est.evaluate.call_count)\n    self.assertEqual(3, mock_est.times_export_was_called)\n    self.assertEqual(1, mock_est.times_final_export_was_called)\n\n  def test_runs_with_eval_listener_before_eval(self):\n    est = estimator_lib.Estimator(\n        model_fn=self._model_fn,\n        config=run_config_lib.RunConfig(save_checkpoints_steps=10))\n    mock_est = tf.compat.v1.test.mock.Mock(\n        spec=estimator_lib.Estimator, wraps=est)\n    mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn\n\n    train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=12)\n    eval_spec = training.EvalSpec(input_fn=lambda: self._input_fn(repeat=False))\n    mock_est.evaluate.side_effect = [{_GLOBAL_STEP_KEY: train_spec.max_steps}]\n\n    class _Listener(training._ContinuousEvalListener):\n\n      def __init__(self):\n        self.call_count = 0\n\n      def before_eval(self):\n        self.call_count += 1\n        return False  # Will stop the run_local before first eval.\n\n    listener = _Listener()\n\n    executor = training._TrainingExecutor(\n        mock_est, train_spec, eval_spec, continuous_eval_listener=listener)\n    executor.run_local()\n\n    self.assertEqual(1, mock_est.train.call_count)\n    self.assertEqual(0, mock_est.evaluate.call_count)\n\n  def test_runs_with_eval_listener_after_eval(self):\n    est = estimator_lib.Estimator(\n        model_fn=self._model_fn,\n        config=run_config_lib.RunConfig(save_checkpoints_steps=10))\n    mock_est = 
tf.compat.v1.test.mock.Mock(\n        spec=estimator_lib.Estimator, wraps=est)\n\n    train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=3000)\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0)\n\n    class _Listener(training._ContinuousEvalListener):\n\n      def __init__(self):\n        self.call_count = 0\n\n      def after_eval(self, eval_result):\n        self.call_count += 1\n        return False  # Will stop the run_local after first eval.\n\n    listener = _Listener()\n\n    executor = training._TrainingExecutor(\n        mock_est, train_spec, eval_spec, continuous_eval_listener=listener)\n    metrics, _ = executor.run_local()  # pylint: disable=assignment-from-no-return\n\n    self.assertEqual(1, mock_est.train.call_count)\n    self.assertEqual(1, mock_est.evaluate.call_count)\n    self.assertEqual(1, listener.call_count)\n    # Should be less than max_steps since listener did early stopping.\n    self.assertLess(metrics[_GLOBAL_STEP_KEY], train_spec.max_steps)\n\n  def test_handles_no_new_checkpoint_found(self):\n    est = estimator_lib.Estimator(\n        model_fn=self._model_fn,\n        # disable saving checkpoint\n        config=run_config_lib.RunConfig(\n            save_checkpoints_steps=None, save_checkpoints_secs=None))\n    train_spec = training.TrainSpec(\n        input_fn=self._input_fn, max_steps=300, hooks=[_FakeHook()])\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: self._input_fn(repeat=False),\n        hooks=[_FakeHook()],\n        throttle_secs=100)\n\n    executor = training._TrainingExecutor(est, train_spec, eval_spec)\n    with self.assertRaisesRegexp(ValueError,\n                                 'There should be a CheckpointSaverHook'):\n      executor.run_local()\n\n  def test_final_export_is_true_in_the_end(self):\n    est = estimator_lib.Estimator(\n        model_fn=self._model_fn,\n        
config=run_config_lib.RunConfig(save_checkpoints_steps=10))\n    mock_est = tf.compat.v1.test.mock.Mock(\n        spec=estimator_lib.Estimator, wraps=est)\n\n    mock_est.times_export_fn_was_called = 0\n    mock_est.times_the_final_export_was_true = 0\n\n    def export(estimator, export_path, checkpoint_path, eval_result,\n               is_the_final_export):\n      del export_path, checkpoint_path, eval_result\n      estimator.times_export_fn_was_called += 1\n      if is_the_final_export:\n        estimator.times_the_final_export_was_true += 1\n\n    exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)\n    exporter.name = 'see_how_many_times_export_is_called'\n    exporter.export = export\n\n    train_spec = training.TrainSpec(\n        input_fn=self._input_fn, max_steps=12, hooks=[_FakeHook()])\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: self._input_fn(repeat=False),\n        throttle_secs=0,\n        exporters=exporter)\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    executor.run_local()\n\n    self.assertEqual(1, mock_est.train.call_count)\n    self.assertEqual(2, mock_est.evaluate.call_count)\n    self.assertEqual(2, mock_est.times_export_fn_was_called)\n    self.assertEqual(1, mock_est.times_the_final_export_was_true)\n\n  def test_train_and_evaluate_args(self):\n    est = estimator_lib.Estimator(model_fn=self._model_fn)\n    mock_est = tf.compat.v1.test.mock.Mock(\n        spec=estimator_lib.Estimator, wraps=est)\n    train_spec = training.TrainSpec(\n        input_fn=self._input_fn, max_steps=300, hooks=[_FakeHook()])\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: self._input_fn(repeat=False),\n        steps=2,\n        hooks=[_FakeHook()],\n        name='local_eval')\n\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    executor.run_local()\n\n    mock_est.evaluate.assert_called_with(\n        name=eval_spec.name,\n        
input_fn=eval_spec.input_fn,\n        steps=eval_spec.steps,\n        checkpoint_path=est.latest_checkpoint(),\n        hooks=eval_spec.hooks)\n\n    train_args = mock_est.train.call_args[1]\n    self.assertEqual(list(train_spec.hooks), list(train_args['hooks']))\n    self.assertEqual(train_spec.input_fn, train_args['input_fn'])\n    self.assertEqual(train_spec.max_steps, train_args['max_steps'])\n\n  def test_train_with_no_eval_spec_fails(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    train_spec = training.TrainSpec(\n        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])\n    eval_spec = None\n\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n\n    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):\n      executor.run_local()\n\n  def test_train_hooks(self):\n    mock_est = tf.compat.v1.test.mock.Mock(\n        spec=estimator_lib.Estimator, model_dir='path/')\n    mock_est.latest_checkpoint.return_value = 'checkpoint_path/'\n    train_spec = training.TrainSpec(\n        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])\n    eval_spec = training.EvalSpec(input_fn=lambda: 1, steps=2)\n    mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}\n    extra_hooks = [_FakeHook()]\n\n    executor = training._TrainingExecutor(\n        mock_est, train_spec, eval_spec, train_hooks=extra_hooks)\n    executor.run_local()\n\n    train_args = mock_est.train.call_args[1]\n    self.assertEqual(\n        list(train_spec.hooks) + extra_hooks, [\n            h for h in train_args['hooks']\n            if not isinstance(h, training._StopAtSecsHook)\n        ])\n\n  def test_that_export_is_called_with_run_local(self):\n    est = estimator_lib.Estimator(model_fn=self._model_fn)\n    mock_est = tf.compat.v1.test.mock.Mock(\n        spec=estimator_lib.Estimator, wraps=est)\n    train_spec = training.TrainSpec(input_fn=self._input_fn, max_steps=12)\n    
mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: train_spec.max_steps}\n\n    def export(estimator, *args, **kwargs):\n      del args, kwargs\n      estimator.export_was_called = True\n      return 'path_to_export'\n\n    exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)\n    exporter.name = 'see_whether_export_is_called'\n    exporter.export = export\n\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: self._input_fn(repeat=False),\n        steps=2,\n        start_delay_secs=0,\n        throttle_secs=213,\n        exporters=exporter)\n\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    # pylint: disable=assignment-from-no-return\n    _, export_results = executor.run_local()\n    # pylint: enable=assignment-from-no-return\n\n    self.assertTrue(mock_est.export_was_called)\n    self.assertEqual(export_results, ['path_to_export'])\n\n  def test_errors_out_if_evaluate_returns_empty_dict(self):\n    est = estimator_lib.Estimator(\n        model_fn=self._model_fn,\n        config=run_config_lib.RunConfig(save_checkpoints_steps=2))\n    mock_est = tf.compat.v1.test.mock.Mock(\n        spec=estimator_lib.Estimator, wraps=est)\n    train_spec = training.TrainSpec(input_fn=self._input_fn)\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0)\n    mock_est.evaluate.return_value = {}\n\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    with self.assertRaisesRegexp(ValueError, _INVALID_EMPTY_EVAL_RESULT_ERR):\n      executor.run_local()\n\n  def test_errors_out_if_evaluate_returns_non_dict(self):\n    est = estimator_lib.Estimator(\n        model_fn=self._model_fn,\n        config=run_config_lib.RunConfig(save_checkpoints_steps=2))\n    mock_est = tf.compat.v1.test.mock.Mock(\n        spec=estimator_lib.Estimator, wraps=est)\n    train_spec = training.TrainSpec(input_fn=self._input_fn)\n    eval_spec = training.EvalSpec(\n    
    input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0)\n    mock_est.evaluate.return_value = 123\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR):\n      executor.run_local()\n\n  def test_errors_out_if_evaluate_returns_dict_without_global_step(self):\n    est = estimator_lib.Estimator(\n        model_fn=self._model_fn,\n        config=run_config_lib.RunConfig(save_checkpoints_steps=2))\n    mock_est = tf.compat.v1.test.mock.Mock(\n        spec=estimator_lib.Estimator, wraps=est)\n    train_spec = training.TrainSpec(input_fn=self._input_fn)\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0)\n    mock_est.evaluate.return_value = {'loss': 123}\n\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    with self.assertRaisesRegexp(ValueError,\n                                 _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR):\n      executor.run_local()\n\n  def test_train_and_evaluate_return_metrics(self):\n    est = estimator_lib.Estimator(model_fn=self._model_fn)\n    mock_est = tf.compat.v1.test.mock.Mock(\n        spec=estimator_lib.Estimator, wraps=est)\n    train_spec = training.TrainSpec(\n        input_fn=self._input_fn, max_steps=12, hooks=[_FakeHook()])\n    eval_spec = training.EvalSpec(\n        input_fn=lambda: self._input_fn(repeat=False),\n        steps=2,\n        hooks=[_FakeHook()],\n        name='local_eval')\n\n    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)\n    # pylint: disable=assignment-from-no-return\n    metrics, _ = executor.run_local()\n    # pylint: enable=assignment-from-no-return\n    self.assertEqual(metrics['global_step'], 12)\n\n\nclass TrainAndEvaluateRunTest(tf.test.TestCase):\n\n  def _test_run_task_and_executor(self, run_config):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    
mock_est.config = run_config\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec,\n                                          mock_eval_spec)\n\n    executor.call_task = {}\n\n    def task_fn(name):\n\n      def _fn():\n        executor.call_task[name] = 1\n\n      return _fn\n\n    executor.run_chief = task_fn('chief')\n    executor.run_master = task_fn('master')\n    executor.run_ps = task_fn('ps')\n    executor.run_evaluator = task_fn('evaluator')\n    executor.run_worker = task_fn('worker')\n    executor.run_local = task_fn('local')\n    return executor\n\n  def test_run_chief(self):\n    executor = self._test_run_task_and_executor(\n        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_CHIEF))\n    executor.run()\n    self.assertEqual(1, executor.call_task['chief'])\n\n  def test_run_worker(self):\n    executor = self._test_run_task_and_executor(\n        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_WORKER))\n    executor.run()\n    self.assertEqual(1, executor.call_task['worker'])\n\n  def test_run_ps(self):\n    executor = self._test_run_task_and_executor(\n        run_config=_create_run_config_with_cluster_spec(_TF_CONFIG_FOR_PS))\n    executor.run()\n    self.assertEqual(1, executor.call_task['ps'])\n\n  def test_run_evaluator(self):\n    executor = self._test_run_task_and_executor(\n        run_config=_create_run_config_with_cluster_spec(\n            _TF_CONFIG_FOR_EVALUATOR))\n    executor.run()\n    self.assertEqual(1, executor.call_task['evaluator'])\n\n  def test_run_local(self):\n    executor = self._test_run_task_and_executor(\n        run_config=run_config_lib.RunConfig())\n    executor.run()\n    self.assertEqual(1, executor.call_task['local'])\n\n  def test_invalid_local_task(self):\n    tf_config = {\n        'cluster': {\n            
run_config_lib.TaskType.CHIEF: ['host0:0'],\n            'local': ['hos1:1'],\n        },\n        'task': {\n            'type': 'local',  # invalid task type.\n            'index': 0\n        }\n    }\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.config = _create_run_config_with_cluster_spec(tf_config)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec,\n                                          mock_eval_spec)\n    with self.assertRaisesRegexp(ValueError, _INVALID_LOCAL_TASK_WITH_CLUSTER):\n      executor.run()\n\n  def test_unsupported_task_due_to_missing_run_task(self):\n    unsupported_task = 'alloc'\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n            unsupported_task: ['hos1:1'],\n        },\n        'task': {\n            'type': unsupported_task,\n            'index': 0\n        }\n    }\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.config = _create_run_config_with_cluster_spec(tf_config)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec,\n                                          mock_eval_spec)\n    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN):\n      executor.run()\n\n  def test_unsupported_task_due_to_not_callable(self):\n    unsupported_task = 'alloc'\n    tf_config = {\n        'cluster': {\n            run_config_lib.TaskType.CHIEF: ['host0:0'],\n            unsupported_task: ['hos1:1'],\n        },\n        'task': {\n            'type': unsupported_task,\n            'index': 0\n        }\n    }\n    mock_est = 
tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.config = _create_run_config_with_cluster_spec(tf_config)\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec,\n                                          mock_eval_spec)\n    executor.run_alloc = 123  # not callable\n    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TO_RUN):\n      executor.run()\n\n  def test_invalid_task_type(self):\n    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)\n    mock_est.config = tf.compat.v1.test.mock.Mock()\n    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)\n    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)\n\n    mock_est.config = tf.compat.v1.test.mock.Mock()\n    mock_est.config.cluster_spec = tf.train.ClusterSpec({'1': ['dummy']})\n    mock_est.config.task_type = ''\n\n    executor = training._TrainingExecutor(mock_est, mock_train_spec,\n                                          mock_eval_spec)\n    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE):\n      executor.run()\n\n\nclass TrainAndEvaluateIntegrationTest(tf.test.TestCase):\n\n  def setUp(self):\n    self._model_dir = tempfile.mkdtemp()\n\n  def tearDown(self):\n    if self._model_dir:\n      shutil.rmtree(self._model_dir)\n\n  def _as_label(self, data_in_float):\n    return np.rint(data_in_float).astype(np.int64)\n\n  def _get_exporter(self, name, fc):\n    feature_spec = tf.compat.v1.feature_column.make_parse_example_spec(fc)\n    serving_input_receiver_fn = (\n        export_lib.build_parsing_serving_input_receiver_fn(feature_spec))\n    return exporter_lib.LatestExporter(\n        name, serving_input_receiver_fn=serving_input_receiver_fn)\n\n  def _extract_loss_and_global_step(self, event_folder):\n    \"\"\"Returns the loss and global step in 
last event.\"\"\"\n    event_paths = glob.glob(os.path.join(event_folder, 'events*'))\n\n    loss = None\n    global_step_count = None\n\n    for e in tf.compat.v1.train.summary_iterator(event_paths[-1]):\n      current_loss = None\n      for v in e.summary.value:\n        if v.tag == 'loss':\n          current_loss = v.simple_value\n\n      # If loss is not found, global step is meaningless.\n      if current_loss is None:\n        continue\n\n      current_global_step = e.step\n      if global_step_count is None or current_global_step > global_step_count:\n        global_step_count = current_global_step\n        loss = current_loss\n\n    return (loss, global_step_count)\n\n  def test_complete_flow_with_non_distributed_configuration(self):\n    n_classes = 3\n    input_dimension = 2\n    batch_size = 10\n\n    eval_name = 'foo'\n    exporter_name = 'saved_model_exporter'\n\n    # max_steps should be larger than save_summary_steps\n    max_steps = 10\n    save_summary_steps = 9\n\n    data = np.linspace(\n        0., n_classes - 1., batch_size * input_dimension, dtype=np.float32)\n    x_data = data.reshape(batch_size, input_dimension)\n    y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1))\n\n    # learn y = x\n    def train_input_fn():\n      return tf.compat.v1.data.Dataset.from_tensor_slices(({\n          'x': x_data\n      }, y_data)).batch(batch_size).repeat().shuffle(1000)\n\n    def eval_input_fn():\n      return tf.compat.v1.data.Dataset.from_tensor_slices(({\n          'x': x_data\n      }, y_data)).batch(batch_size)\n\n    def predict_input_fn():\n      return tf.compat.v1.data.Dataset.from_tensor_slices({\n          'x': x_data\n      }).batch(batch_size)\n\n    feature_columns = [\n        tf.feature_column.numeric_column('x', shape=(input_dimension,))\n    ]\n\n    est = dnn.DNNClassifier(\n        hidden_units=(2, 2),\n        feature_columns=feature_columns,\n        n_classes=n_classes,\n        
config=run_config_lib.RunConfig(save_summary_steps=save_summary_steps),\n        model_dir=self._model_dir)\n\n    train_spec = training.TrainSpec(\n        input_fn=train_input_fn, max_steps=max_steps)\n\n    eval_spec = training.EvalSpec(\n        name=eval_name,\n        input_fn=eval_input_fn,\n        steps=None,\n        exporters=self._get_exporter(exporter_name, feature_columns),\n        throttle_secs=0)\n\n    training.train_and_evaluate(est, train_spec, eval_spec)\n\n    # Make sure nothing is stuck in limbo.\n    tf.compat.v1.summary.FileWriterCache.clear()\n\n    # Examine the training events. Use a range to check global step to avoid\n    # flakyness due to global step race condition.\n    training_loss, _ = self._extract_loss_and_global_step(est.model_dir)\n    self.assertIsNotNone(training_loss)\n\n    # Examine the eval events. The global step should be accurate.\n    eval_loss, eval_global_step = self._extract_loss_and_global_step(\n        event_folder=est.eval_dir(eval_name))\n    self.assertIsNotNone(eval_loss)\n    self.assertEqual(max_steps, eval_global_step)\n\n    # Examine the export folder.\n    export_dir = os.path.join(\n        os.path.join(est.model_dir, 'export'), exporter_name)\n    self.assertTrue(tf.compat.v1.gfile.Exists(export_dir))\n\n    # Examine the ckpt for predict.\n    predicted_proba = np.array([\n        x[prediction_keys.PredictionKeys.PROBABILITIES]\n        for x in est.predict(predict_input_fn)\n    ])\n    self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/util.py",
    "content": "# Copyright 2017 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Utilities for Estimators.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport time\nimport tensorflow as tf\n\n# import keras 2\nversion_fn = getattr(tf.keras, 'version', None)\nif version_fn and version_fn().startswith('3.'):\n  import tf_keras  # pylint: disable=g-import-not-at-top,unused-import\n  from tf_keras.api._v1 import keras as tf_keras_v1  # pylint: disable=g-import-not-at-top,unused-import\n  from tf_keras.api._v2 import keras as tf_keras_v2  # pylint: disable=g-import-not-at-top,unused-import\nelse:\n  tf_keras = tf.keras  # Keras 2\n  tf_keras_v1 = tf.compat.v1.keras\n  tf_keras_v2 = tf.compat.v2.keras\n\nfrom tensorflow.python.util import function_utils\n\nfn_args = function_utils.fn_args\n\n# When we create a timestamped directory, there is a small chance that the\n# directory already exists because another process is also creating these\n# directories. In this case we just wait one second to get a new timestamp and\n# try again. 
If this fails several times in a row, then something is seriously\n# wrong.\nMAX_DIRECTORY_CREATION_ATTEMPTS = 10\n\n\ndef parse_input_fn_result(result):\n  \"\"\"Gets features, labels, and hooks from the result of an Estimator input_fn.\n\n  Args:\n    result: output of an input_fn to an estimator, which should be one of:\n      * A 'tf.data.Dataset' object: Outputs of `Dataset` object must be a tuple\n        (features, labels) with same constraints as below.\n      * A tuple (features, labels): Where `features` is a `Tensor` or a\n        dictionary of string feature name to `Tensor` and `labels` is a `Tensor`\n        or a dictionary of string label name to `Tensor`. Both `features` and\n        `labels` are consumed by `model_fn`. They should satisfy the expectation\n        of `model_fn` from inputs.\n\n  Returns:\n    Tuple of features, labels, and input_hooks, where features are as described\n    above, labels are as described above or None, and input_hooks are a list\n    of SessionRunHooks to be included when running.\n\n  Raises:\n    ValueError: if the result is a list or tuple of length != 2.\n  \"\"\"\n  input_hooks = []\n  if isinstance(result, tf.compat.v2.data.Dataset):\n    iterator = tf.compat.v1.data.make_initializable_iterator(result)\n    input_hooks.append(_DatasetInitializerHook(iterator))\n    result = iterator.get_next()\n  return parse_iterator_result(result) + (input_hooks,)\n\n\ndef parse_iterator_result(result):\n  \"\"\"Gets features, labels from result.\"\"\"\n  if isinstance(result, (list, tuple)):\n    if len(result) != 2:\n      raise ValueError(\n          'input_fn should return (features, labels) as a len 2 tuple.')\n    return result[0], result[1]\n  return result, None\n\n\nclass _DatasetInitializerHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Creates a SessionRunHook that initializes the passed iterator.\"\"\"\n\n  def __init__(self, iterator):\n    self._iterator = iterator\n\n  def begin(self):\n    self._initializer = 
self._iterator.initializer\n\n  def after_create_session(self, session, coord):\n    del coord\n    session.run(self._initializer)\n\n\nclass DistributedIteratorInitializerHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Creates a SessionRunHook that initializes the passed iterator.\"\"\"\n\n  def __init__(self, iterator):\n    self._iterator = iterator\n\n  def begin(self):\n    self._initializer = self._iterator.initialize()\n\n  def after_create_session(self, session, coord):\n    del coord\n    session.run(self._initializer)\n\n\nclass MultiHostDatasetInitializerHook(tf.compat.v1.train.SessionRunHook):\n  \"\"\"Creates a SessionRunHook that initializes all passed iterators.\"\"\"\n\n  def __init__(self, dataset_initializers):\n    self._initializers = dataset_initializers\n\n  def after_create_session(self, session, coord):\n    del coord\n    start = time.time()\n    session.run(self._initializers)\n    tf.compat.v1.logging.info('Initialized dataset iterators in %d seconds',\n                              time.time() - start)\n"
  },
  {
    "path": "tensorflow_estimator/python/estimator/util_test.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Tests for util.py.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom absl.testing import parameterized\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.framework import test_util\nfrom tensorflow_estimator.python.estimator import util\n\n\n@test_util.deprecated_graph_mode_only\nclass UtilTest(tf.test.TestCase, parameterized.TestCase):\n  \"\"\"Tests for miscellaneous Estimator utils.\"\"\"\n\n  def test_parse_input_fn_result_tuple(self):\n\n    def _input_fn():\n      features = tf.constant(np.arange(100))\n      labels = tf.constant(np.arange(100, 200))\n      return features, labels\n\n    features, labels, hooks = util.parse_input_fn_result(_input_fn())\n\n    with self.cached_session() as sess:\n      vals = sess.run([features, labels])\n\n    self.assertAllEqual(vals[0], np.arange(100))\n    self.assertAllEqual(vals[1], np.arange(100, 200))\n    self.assertEqual(hooks, [])\n\n  @parameterized.named_parameters(('DatasetV1', tf.compat.v1.data.Dataset),\n                                  ('DatasetV2', tf.data.Dataset))\n  def test_parse_input_fn_result_dataset(self, dataset_class):\n\n    def _input_fn():\n      features = 
np.expand_dims(np.arange(100), 0)\n      labels = np.expand_dims(np.arange(100, 200), 0)\n      return dataset_class.from_tensor_slices((features, labels))\n\n    features, labels, hooks = util.parse_input_fn_result(_input_fn())\n\n    with tf.compat.v1.train.MonitoredSession(hooks=hooks) as sess:\n      vals = sess.run([features, labels])\n\n    self.assertAllEqual(vals[0], np.arange(100))\n    self.assertAllEqual(vals[1], np.arange(100, 200))\n    self.assertIsInstance(hooks[0], util._DatasetInitializerHook)\n\n  def test_parse_input_fn_result_features_only(self):\n\n    def _input_fn():\n      return tf.constant(np.arange(100))\n\n    features, labels, hooks = util.parse_input_fn_result(_input_fn())\n\n    with self.cached_session() as sess:\n      vals = sess.run([features])\n\n    self.assertAllEqual(vals[0], np.arange(100))\n    self.assertEqual(labels, None)\n    self.assertEqual(hooks, [])\n\n  @parameterized.named_parameters(('DatasetV1', tf.compat.v1.data.Dataset),\n                                  ('DatasetV2', tf.data.Dataset))\n  def test_parse_input_fn_result_features_only_dataset(self, dataset_class):\n\n    def _input_fn():\n      features = np.expand_dims(np.arange(100), 0)\n      return dataset_class.from_tensor_slices(features)\n\n    features, labels, hooks = util.parse_input_fn_result(_input_fn())\n\n    with tf.compat.v1.train.MonitoredSession(hooks=hooks) as sess:\n      vals = sess.run([features])\n\n    self.assertAllEqual(vals[0], np.arange(100))\n    self.assertEqual(labels, None)\n    self.assertIsInstance(hooks[0], util._DatasetInitializerHook)\n\n  @parameterized.named_parameters(('DatasetV1', tf.compat.v1.data.Dataset),\n                                  ('DatasetV2', tf.data.Dataset))\n  def test_parse_input_fn_result_invalid(self, dataset_class):\n\n    def _input_fn():\n      features = np.expand_dims(np.arange(100), 0)\n      labels = np.expand_dims(np.arange(100, 200), 0)\n      return dataset_class.from_tensor_slices((features, 
labels, labels))\n\n    with self.assertRaisesRegexp(ValueError, 'input_fn should return'):\n      util.parse_input_fn_result(_input_fn())\n\n\nif __name__ == '__main__':\n  tf.test.main()\n"
  },
  {
    "path": "tensorflow_estimator/tools/pip_package/BUILD",
    "content": "package(default_visibility = [\"//tensorflow_estimator:internal\"])\n\n# Description:\n#  Tools for building the TensorFlow pip package.\n\nCOMMON_PIP_DEPS = [\n    \"//tensorflow_estimator\",\n    # Need to include testing libraries in pip package so our pip\n    # release tests can run. (see py_test rule in estimator.bzl for more context).\n    # Essentially, everything needed to run the test (except the test file itself)\n    # must be contained in the pip package since we strip away all deps.\n    \"//tensorflow_estimator/python/estimator:dnn_testing_utils\",\n    \"//tensorflow_estimator/python/estimator:dnn_testing_utils_v1\",\n    \"//tensorflow_estimator/python/estimator:linear_testing_utils\",\n    \"//tensorflow_estimator/python/estimator:linear_testing_utils_v1\",\n]\n\nsh_binary(\n    name = \"build_pip_package\",\n    srcs = [\"build_pip_package.sh\"],\n    data = COMMON_PIP_DEPS,\n)\n"
  },
  {
    "path": "tensorflow_estimator/tools/pip_package/build_pip_package.sh",
    "content": "#!/usr/bin/env bash\n# Copyright 2015 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\nset -e\n\nfunction is_absolute {\n  [[ \"$1\" = /* ]] || [[ \"$1\" =~ ^[a-zA-Z]:[/\\\\].* ]]\n}\n\nfunction real_path() {\n  is_absolute \"$1\" && echo \"$1\" || echo \"$PWD/${1#./}\"\n}\n\nfunction prepare_src() {\n  TMPDIR=\"$1\"\n\n  mkdir -p \"$TMPDIR\"\n  echo $(date) : \"=== Preparing sources in dir: ${TMPDIR}\"\n\n  if [ ! -d bazel-bin/tensorflow_estimator ]; then\n    echo \"Could not find bazel-bin.  
Did you run from the root of the build tree?\"\n    exit 1\n  fi\n  cp -r \"bazel-bin/tensorflow_estimator/tools/pip_package/build_pip_package.runfiles/org_tensorflow_estimator/tensorflow_estimator\" \"$TMPDIR\"\n  cp tensorflow_estimator/tools/pip_package/setup.py \"$TMPDIR\"\n\n  # Verifies all expected files are in pip.\n  # Creates init files in all directory in pip.\n  python tensorflow_estimator/tools/pip_package/create_pip_helper.py --pip-root \"${TMPDIR}/tensorflow_estimator/\" --bazel-root \"./tensorflow_estimator\"\n}\n\nfunction build_wheel() {\n  if [ $# -lt 2 ] ; then\n    echo \"No src and dest dir provided\"\n    exit 1\n  fi\n\n  TMPDIR=\"$1\"\n  DEST=\"$2\"\n  PROJECT_NAME=\"$3\"\n\n  pushd ${TMPDIR} > /dev/null\n  echo $(date) : \"=== Building wheel\"\n  \"${PYTHON_BIN_PATH:-python}\" setup.py bdist_wheel --universal --project_name $PROJECT_NAME\n  mkdir -p ${DEST}\n  cp dist/* ${DEST}\n  popd > /dev/null\n  echo $(date) : \"=== Output wheel file is in: ${DEST}\"\n}\n\nfunction usage() {\n  echo \"Usage:\"\n  echo \"$0 [--src srcdir] [--dst dstdir] [options]\"\n  echo \"$0 dstdir [options]\"\n  echo \"\"\n  echo \"    --src                 prepare sources in srcdir\"\n  echo \"                              will use temporary dir if not specified\"\n  echo \"\"\n  echo \"    --dst                 build wheel in dstdir\"\n  echo \"                              if dstdir is not set do not build, only prepare sources\"\n  echo \"\"\n  echo \"  Options:\"\n  echo \"    --project_name <name> set project name to name\"\n  echo \"    --nightly             build tensorflow_estimator nightly\"\n  echo \"\"\n  exit 1\n}\n\nfunction main() {\n  NIGHTLY_BUILD=0\n  PROJECT_NAME=\"\"\n  SRCDIR=\"\"\n  DSTDIR=\"\"\n  CLEANSRC=1\n\n  while true; do\n    if [[ -z \"$1\" ]]; then\n      break\n    elif [[ \"$1\" == \"--help\" ]]; then\n      usage\n      exit 1\n    elif [[ \"$1\" == \"--nightly\" ]]; then\n      NIGHTLY_BUILD=1\n    elif [[ \"$1\" == 
\"--project_name\" ]]; then\n      shift\n      if [[ -z \"$1\" ]]; then\n        break\n      fi\n      PROJECT_NAME=\"$1\"\n    elif [[ \"$1\" == \"--src\" ]]; then\n      shift\n      if [[ -z \"$1\" ]]; then\n        break\n      fi\n      SRCDIR=\"$(real_path $1)\"\n      CLEANSRC=0\n    elif [[ \"$1\" == \"--dst\" ]]; then\n      shift\n      if [[ -z \"$1\" ]]; then\n        break\n      fi\n      DSTDIR=\"$(real_path $1)\"\n    else\n      DSTDIR=\"$(real_path $1)\"\n    fi\n    shift\n  done\n\n  if [[ -z ${PROJECT_NAME} ]]; then\n    PROJECT_NAME=\"tensorflow_estimator\"\n    if [[ ${NIGHTLY_BUILD} == \"1\" ]]; then\n      PROJECT_NAME=\"tf_estimator_nightly\"\n    fi\n  fi\n\n  if [[ -z \"$DSTDIR\" ]] && [[ -z \"$SRCDIR\" ]]; then\n    echo \"No destination dir provided\"\n    usage\n    exit 1\n  fi\n\n  if [[ -z \"$SRCDIR\" ]]; then\n    # make temp srcdir if none set\n    SRCDIR=\"$(mktemp -d -t tmp.XXXXXXXXXX)\"\n  fi\n\n  prepare_src \"$SRCDIR\"\n\n  if [[ -z \"$DSTDIR\" ]]; then\n      # only want to prepare sources\n      exit\n  fi\n\n  build_wheel \"$SRCDIR\" \"$DSTDIR\" \"$PROJECT_NAME\"\n\n  if [[ $CLEANSRC -ne 0 ]]; then\n    rm -rf \"${TMPDIR}\"\n  fi\n}\n\nmain \"$@\"\n"
  },
  {
    "path": "tensorflow_estimator/tools/pip_package/create_pip_helper.py",
    "content": "# Copyright 2019 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"Utils to help build and verify pip package for TensorFlow Estimator.\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport argparse\nimport fnmatch\nimport os\n\nPIP_EXCLUDED_FILES = frozenset([\n    'tensorflow_estimator/python/estimator/canned/optimizers_test_v2.py',\n    'tensorflow_estimator/python/estimator/canned/dnn_test_fc_v2.py',\n    'tensorflow_estimator/python/estimator/canned/dnn_test_fc_v1.py',\n    'tensorflow_estimator/python/estimator/canned/v1/dnn_estimator_test_v1.py',\n    'tensorflow_estimator/python/estimator/canned/v1/linear_test_v1.py',\n    'tensorflow_estimator/python/estimator/canned/v1/dnn_linear_combined_estimator_test_v1.py',\n    'tensorflow_estimator/python/estimator/canned/v1/dnn_linear_combined_test_v1.py',\n    'tensorflow_estimator/python/estimator/canned/v1/baseline_estimator_test_v1.py',\n    'tensorflow_estimator/python/estimator/canned/v1/linear_estimator_test_v1.py',\n    'tensorflow_estimator/python/estimator/canned/v1/baseline_test_v1.py',\n    'tensorflow_estimator/python/estimator/canned/v1/dnn_test_fc_v1_v1.py',\n    'tensorflow_estimator/python/estimator/canned/v1/dnn_test_fc_v2_v1.py',\n    
'tensorflow_estimator/python/estimator/api/extractor_wrapper.py',\n    'tensorflow_estimator/python/estimator/api/generator_wrapper.py',\n    'tensorflow_estimator/tools/pip_package/setup.py',\n    'tensorflow_estimator/tools/pip_package/create_pip_helper.py',\n])\n\n# Directories that should not have __init__.py files generated within them.\nEXCLUDED_INIT_FILE_DIRECTORIES = frozenset(['tensorflow_estimator/tools'])\n\n\nclass PipPackagingError(Exception):\n  pass\n\n\ndef create_init_files(pip_root):\n  \"\"\"Create __init__.py in pip directory tree.\n\n  These files are auto-generated by Bazel when doing typical build/test, but\n  do not get auto-generated by the pip build process. Currently, the entire\n  directory tree is just python files, so its fine to just create all of the\n  init files.\n\n  Args:\n    pip_root: Root directory of code being packaged into pip.\n\n  Returns:\n    True: contrib code is included in pip.\n  \"\"\"\n  has_contrib = False\n  for path, subdirs, _ in os.walk(pip_root):\n    has_contrib = has_contrib or '/contrib/' in path\n    for subdir in subdirs:\n      init_file_path = os.path.join(path, subdir, '__init__.py')\n      if any(excluded_path in init_file_path\n             for excluded_path in EXCLUDED_INIT_FILE_DIRECTORIES):\n        continue\n      if not os.path.exists(init_file_path):\n        # Create empty file\n        open(init_file_path, 'w').close()\n  return has_contrib\n\n\ndef verify_python_files_in_pip(pip_root, bazel_root, has_contrib):\n  \"\"\"Verifies all expected files are packaged into Pip.\n\n  Args:\n    pip_root: Root directory of code being packaged into pip.\n    bazel_root: Root directory of Estimator Bazel workspace.\n    has_contrib: Code from contrib/ should be included in pip.\n\n  Raises:\n    PipPackagingError: Missing file in pip.\n  \"\"\"\n  for path, _, files in os.walk(bazel_root):\n    if not has_contrib and '/contrib/' in path:\n      continue\n    python_files = set(fnmatch.filter(files, 
'*.py'))\n    python_test_files = set(fnmatch.filter(files, '*test.py'))\n    # We only care about python files in the pip package, see create_init_files.\n    files = python_files - python_test_files\n    for f in files:\n      pip_path = os.path.join(pip_root, os.path.relpath(path, bazel_root), f)\n      file_name = os.path.join(path, f)\n      path_exists = os.path.exists(pip_path)\n      file_excluded = file_name.lstrip('./') in PIP_EXCLUDED_FILES\n      if not path_exists and not file_excluded:\n        raise PipPackagingError(\n            ('Pip package missing the file %s. If this is expected, add it '\n             'to PIP_EXCLUDED_FILES in create_pip_helper.py. Otherwise, '\n             'make sure it is a build dependency of the pip package') %\n            file_name)\n      if path_exists and file_excluded:\n        raise PipPackagingError(\n            ('File in PIP_EXCLUDED_FILES included in pip. %s' % file_name))\n\n\ndef main():\n  parser = argparse.ArgumentParser()\n  parser.add_argument(\n      '--bazel-root',\n      type=str,\n      required=True,\n      help='Root directory of Estimator Bazel workspace.')\n  parser.add_argument(\n      '--pip-root',\n      type=str,\n      required=True,\n      help='Root directory of code being packaged into pip.')\n\n  args = parser.parse_args()\n  has_contrib = create_init_files(args.pip_root)\n  verify_python_files_in_pip(args.pip_root, args.bazel_root, has_contrib)\n\n\nif __name__ == '__main__':\n  main()\n"
  },
  {
    "path": "tensorflow_estimator/tools/pip_package/setup.py",
    "content": "# Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n# ==============================================================================\n\"\"\"TensorFlow Estimator.\n\nTensorFlow Estimator is a high-level API that encapsulates model training,\nevaluation, prediction, and exporting.\n\"\"\"\n\nfrom __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nimport sys\nimport setuptools\n\nDOCLINES = __doc__.split('\\n')\n\n# This version string is semver compatible, but incompatible with pip.\n# For pip, we will remove all '-' characters from this string, and use the\n# result for pip.\n_VERSION = '2.16.0'\n\nREQUIRED_PACKAGES = [\n    # We depend on TensorFlow's declared pip dependencies.\n    # Add a new dep there if one is needed.\n]\n\nproject_name = 'tensorflow_estimator'\nif '--project_name' in sys.argv:\n  project_name_idx = sys.argv.index('--project_name')\n  project_name = sys.argv[project_name_idx + 1]\n  sys.argv.remove('--project_name')\n  sys.argv.pop(project_name_idx)\n\nsetuptools.setup(\n    name=project_name,\n    version=_VERSION.replace('-', ''),\n    description=DOCLINES[0],\n    long_description='\\n'.join(DOCLINES[2:]),\n    url='https://www.tensorflow.org/',\n    download_url='https://github.com/tensorflow/estimator/tags',\n    author='Google Inc.',\n    packages=setuptools.find_packages(),\n    
install_requires=REQUIRED_PACKAGES,\n    # PyPI package information.\n    # Supported Python versions\n    python_requires='>=3.7',\n    classifiers=[\n        'Development Status :: 5 - Production/Stable',\n        'Intended Audience :: Developers',\n        'Intended Audience :: Education',\n        'Intended Audience :: Science/Research',\n        'License :: OSI Approved :: Apache Software License',\n        'Programming Language :: Python :: 3',\n        'Programming Language :: Python :: 3.7',\n        'Programming Language :: Python :: 3.8',\n        'Programming Language :: Python :: 3.9',\n        'Programming Language :: Python :: 3.10',\n        'Topic :: Scientific/Engineering',\n        'Topic :: Scientific/Engineering :: Mathematics',\n        'Topic :: Scientific/Engineering :: Artificial Intelligence',\n        'Topic :: Software Development',\n        'Topic :: Software Development :: Libraries',\n        'Topic :: Software Development :: Libraries :: Python Modules',\n    ],\n    license='Apache 2.0',\n    keywords='tensorflow estimator tensor machine learning',\n)\n"
  },
  {
    "path": "third_party/py/BUILD",
    "content": ""
  },
  {
    "path": "third_party/py/BUILD.tpl",
    "content": "licenses([\"restricted\"])\n\npackage(default_visibility = [\"//visibility:public\"])\n\n# Point both runtimes to the same python binary to ensure we always\n# use the python binary specified by ./configure.py script.\nload(\"@bazel_tools//tools/python:toolchain.bzl\", \"py_runtime_pair\")\n\npy_runtime(\n    name = \"py2_runtime\",\n    interpreter_path = \"%{PYTHON_BIN_PATH}\",\n    python_version = \"PY2\",\n)\n\npy_runtime(\n    name = \"py3_runtime\",\n    interpreter_path = \"%{PYTHON_BIN_PATH}\",\n    python_version = \"PY3\",\n)\n\npy_runtime_pair(\n    name = \"py_runtime_pair\",\n    py2_runtime = \":py2_runtime\",\n    py3_runtime = \":py3_runtime\",\n)\n\ntoolchain(\n    name = \"py_toolchain\",\n    toolchain = \":py_runtime_pair\",\n    toolchain_type = \"@bazel_tools//tools/python:toolchain_type\",\n)\n"
  },
  {
    "path": "third_party/py/python_configure.bzl",
    "content": "\"\"\"Repository rule for Python autoconfiguration.\n\n`python_configure` depends on the following environment variables:\n\n  * `PYTHON_BIN_PATH`: location of python binary.\n\"\"\"\n\n_PYTHON_BIN_PATH = \"PYTHON_BIN_PATH\"\n\ndef _tpl(repository_ctx, tpl, substitutions = {}, out = None):\n    if not out:\n        out = tpl\n    repository_ctx.template(\n        out,\n        Label(\"//third_party/py:%s.tpl\" % tpl),\n        substitutions,\n    )\n\ndef _fail(msg):\n    \"\"\"Output failure message when auto configuration fails.\"\"\"\n    red = \"\\033[0;31m\"\n    no_color = \"\\033[0m\"\n    fail(\"%sPython Configuration Error:%s %s\\n\" % (red, no_color, msg))\n\ndef _get_python_bin(repository_ctx):\n    \"\"\"Gets the python bin path.\"\"\"\n    python_bin = repository_ctx.os.environ.get(_PYTHON_BIN_PATH)\n    if python_bin != None:\n        return python_bin\n    python_bin_path = repository_ctx.which(\"python\")\n    if python_bin_path != None:\n        return str(python_bin_path)\n    _fail(\"Cannot find python in PATH, please make sure \" +\n          \"python is installed and add its directory in PATH, or --define \" +\n          \"%s='/something/else'.\\nPATH=%s\" % (\n              _PYTHON_BIN_PATH,\n              repository_ctx.os.environ.get(\"PATH\", \"\"),\n          ))\n\ndef _create_local_python_repository(repository_ctx):\n    \"\"\"Creates the repository containing files set up to build with Python.\"\"\"\n    python_bin = _get_python_bin(repository_ctx)\n    _tpl(repository_ctx, \"BUILD\", {\n        \"%{PYTHON_BIN_PATH}\": python_bin,\n    })\n\ndef _python_autoconf_impl(repository_ctx):\n    \"\"\"Implementation of the python_autoconf repository rule.\"\"\"\n    _create_local_python_repository(repository_ctx)\n\npython_configure = repository_rule(\n    implementation = _python_autoconf_impl,\n    environ = [\n        _PYTHON_BIN_PATH,\n    ],\n)\n\"\"\"Detects and configures the local Python toolchain.\n\nAdd the following 
to your WORKSPACE FILE:\n\n```python\nload(\"//third_party/py:python_configure.bzl\", \"python_configure\")\n\npython_configure(name = \"local_config_py_toolchain\")\n\nregister_toolchains(\"@local_config_py_toolchain//:py_toolchain\")\n```\n\nArgs:\n  name: A unique name for this workspace rule.\n\"\"\"\n"
  }
]